package websocket import ( "github.com/gorilla/websocket" "go.uber.org/zap" "google.golang.org/protobuf/proto" "io" "time" ) type writer struct { // conn is the websocket connection that this writer is responsible for writing on. conn *websocket.Conn // channel is the channel used to receive server messages to be sent to the client. // When it receives a SocketClosed, the process on the sending end promises not to send any further messages, as the writer will close it right after. channel chan ServerCommand // readNotifications is the channel used to receive pings when the reader receives a message, so that a ping will be sent out before the reader is ready to time out. // When it is closed, the reader has shut down. readNotifications <-chan time.Time // timer is the timer used to send pings when the reader is close to timing out, to make sure the other end of the connection is still listening. timer *time.Timer // nextPingAt is the time after which the next ping will be sent if the timer is ticking. It will be .IsZero() if the timer is not ticking. nextPingAt time.Time // logger is the logger used to record the state of the writer, primarily in Debug level. logger *zap.Logger } // IsTimerTicking returns true if the writer's timer is running. func (w *writer) isTimerTicking() bool { return !w.nextPingAt.IsZero() } // sendNextPingAt starts the timer if it's not running, and func (w *writer) sendNextPingAt(nextPing time.Time) { if w.nextPingAt.IsZero() { // Timer is not running, so set the next ping time and start it. w.nextPingAt = nextPing w.timer.Reset(time.Until(nextPing)) } else if w.nextPingAt.Before(nextPing) { // Timer's already running, so leave it be, but update the next ping time. w.nextPingAt = nextPing } else { // The timer is already set to a time after the incoming time. // It's extremely unlikely for this empty branch to ever be reached. } } // act is the function responsible for actually doing the writing. func (w *writer) act() { defer w.gracefulShutdown() w.logger.Debug("Starting up") w.timer = time.NewTimer(PingDelay) w.nextPingAt = time.Now().Add(PingDelay) for { select { case readAt, open := <-w.readNotifications: if open { nextPingAt := readAt.Add(PingDelay) w.logger.Debug("Received reader read, extending ping timer", zap.Time("nextPingAt", nextPingAt)) w.sendNextPingAt(nextPingAt) } else { w.logger.Debug("Received reader close, shutting down") w.readNotifications = nil // bye bye, we'll graceful shutdown because we deferred it return } case raw := <-w.channel: switch msg := raw.(type) { case SocketClosed: w.logger.Debug("Received close message, forwarding and shutting down", zap.Object("msg", msg)) w.sendClose(msg) // bye bye, we'll graceful shutdown because we deferred it return default: w.logger.Debug("Received message, forwarding", zap.Object("msg", msg)) w.send(msg) } case <-w.timer.C: now := time.Now() if now.After(w.nextPingAt) { // We successfully passed the time when a ping should be sent! Let's send it! w.sendPing() // The timer doesn't need to be reactivated right now, so just zero out the next ping time. w.nextPingAt = time.Time{} } else { // It's not time to send the ping yet - we got more reads in the meantime. Restart the timer with the new time-until-next-ping. w.timer.Reset(w.nextPingAt.Sub(now)) } } w.logger.Debug("Awakening handled, resuming listening") } } // send actually transmits a ServerCommand to the client according to the protocol. func (w *writer) send(msg ServerCommand) { w.logger.Debug("Marshaling command as protobuf", zap.Object("msg", msg)) marshaled, err := proto.Marshal(msg.ToServerPB()) if err != nil { w.logger.Error("Error while marshaling to protobuf", zap.Error(err)) return } writeDeadline := time.Now().Add(WriteTimeLimit) w.logger.Debug("Setting deadline to write command", zap.Time("writeDeadline", writeDeadline)) err = w.conn.SetWriteDeadline(writeDeadline) if err != nil { w.logger.Error("Error while setting write deadline", zap.Time("writeDeadline", writeDeadline), zap.Object("msg", msg), zap.Error(err)) } w.logger.Debug("Opening message writer to send command", zap.Object("msg", msg)) writer, err := w.conn.NextWriter(websocket.TextMessage) if err != nil { w.logger.Error("Error while getting writer from connection", zap.Error(err)) return } defer func(writer io.WriteCloser) { w.logger.Debug("Closing message writer to send command") err := writer.Close() if err != nil { w.logger.Error("Error while closing writer to send command", zap.Error(err)) } w.logger.Debug("Command sent") }(writer) _, err = writer.Write(marshaled) if err != nil { w.logger.Error("Error while writing marshaled protobuf to connection", zap.Error(err)) return } // Deferred close happens now } // sendClose sends a close message on the websocket connection, but does not actually close the connection. // It does, however, close the incoming message channel to the writer. func (w *writer) sendClose(msg SocketClosed) { w.logger.Debug("Shutting down the writer channel") close(w.channel) w.channel = nil w.logger.Debug("Writing close message") err := w.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(msg.Code, msg.Text), time.Now().Add(ControlTimeLimit)) if err != nil { w.logger.Warn("Error while sending close", zap.Error(err)) } } // sendPing sends a ping message on the websocket connection. The content is arbitrary. func (w *writer) sendPing() { w.logger.Debug("Sending ping") err := w.conn.WriteControl(websocket.PingMessage, []byte(PingData), time.Now().Add(ControlTimeLimit)) if err != nil { w.logger.Error("Error while sending ping", zap.Error(err)) } } // gracefulShutdown causes the writer to wait for the close handshake to finish and then shut down. // It waits for the reader's readNotifications to close, indicating that it has also shut down, and for the channel to // receive a SocketClosed message indicating that the main process has shut down. // During this time, the writer ignores all other messages from the channel and sends no pings. func (w *writer) gracefulShutdown() { defer w.finalShutdown() // If the ping timer is still running, stop it and then close it. if w.isTimerTicking() && !w.timer.Stop() { <-w.timer.C } w.timer = nil w.nextPingAt = time.Time{} w.logger.Debug("Waiting for all channels to shut down") for { if w.channel == nil && w.readNotifications == nil { w.logger.Debug("All channels closed, beginning final shutdown") // all done, we outta here, let the defer pick up the final shutdown return } select { case _, open := <-w.readNotifications: if !open { w.logger.Debug("Received reader close while shutting down") w.readNotifications = nil } case raw := <-w.channel: switch msg := raw.(type) { case SocketClosed: w.logger.Debug("Received close message from channel while shutting down, forwarding", zap.Object("msg", msg)) w.sendClose(msg) default: w.logger.Debug("Ignoring non-close message while shutting down", zap.Object("msg", msg)) } } } } // finalShutdown closes the socket and finishes cleanup. func (w *writer) finalShutdown() { w.logger.Debug("Closing WebSocket connection") err := w.conn.Close() if err != nil { w.logger.Error("Received an error while closing", zap.Error(err)) } w.logger.Debug("Shut down") }