You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
197 lines
7.3 KiB
197 lines
7.3 KiB
package websocket
|
|
|
|
import (
|
|
"github.com/gorilla/websocket"
|
|
"go.uber.org/zap"
|
|
"google.golang.org/protobuf/proto"
|
|
"io"
|
|
"time"
|
|
)
|
|
|
|
type writer struct {
|
|
// conn is the websocket connection that this writer is responsible for writing on.
|
|
conn *websocket.Conn
|
|
// channel is the channel used to receive server messages to be sent to the client.
|
|
// When it receives a SocketClosed, the process on the sending end promises not to send any further messages, as the writer will close it right after.
|
|
channel chan ServerCommand
|
|
// readNotifications is the channel used to receive pings when the reader receives a message, so that a ping will be sent out before the reader is ready to time out.
|
|
// When it is closed, the reader has shut down.
|
|
readNotifications <-chan time.Time
|
|
// timer is the timer used to send pings when the reader is close to timing out, to make sure the other end of the connection is still listening.
|
|
timer *time.Timer
|
|
// nextPingAt is the time after which the next ping will be sent if the timer is ticking. It will be .IsZero() if the timer is not ticking.
|
|
nextPingAt time.Time
|
|
// logger is the logger used to record the state of the writer, primarily in Debug level.
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// IsTimerTicking returns true if the writer's timer is running.
|
|
func (w *writer) isTimerTicking() bool {
|
|
return !w.nextPingAt.IsZero()
|
|
}
|
|
|
|
// sendNextPingAt schedules the next ping for no earlier than nextPing.
//
// If the timer is idle (nextPingAt is zero), nextPing is recorded and the timer
// is armed to fire at that moment. If the timer is already armed, it is left
// running and only nextPingAt is moved forward when nextPing is later; when the
// timer then fires early, act re-arms it for the remaining time. A nextPing that
// is not later than the currently scheduled time is ignored.
//
// Precondition: when nextPingAt is zero, the timer must have no pending,
// unreceived fire, since it is Reset directly in that case.
func (w *writer) sendNextPingAt(nextPing time.Time) {
	switch {
	case w.nextPingAt.IsZero():
		w.nextPingAt = nextPing
		w.timer.Reset(time.Until(nextPing))
	case w.nextPingAt.Before(nextPing):
		w.nextPingAt = nextPing
	}
}
|
|
|
|
// act is the writer's main loop. It forwards every message received on channel
// to the client, and it sends a ping PingDelay after the most recent read
// reported on readNotifications (initially, PingDelay after startup). It returns
// when either the reader closes readNotifications or a SocketClosed arrives on
// channel; the deferred gracefulShutdown then waits for the other side to finish
// and closes the connection.
func (w *writer) act() {
	defer w.gracefulShutdown()
	w.logger.Debug("Starting up")

	// Fix the deadline first and derive the timer's duration from it, so the
	// timer cannot fire before nextPingAt and cause a needless re-arm.
	w.nextPingAt = time.Now().Add(PingDelay)
	w.timer = time.NewTimer(time.Until(w.nextPingAt))

	for {
		select {
		case readAt, open := <-w.readNotifications:
			if !open {
				w.logger.Debug("Received reader close, shutting down")
				w.readNotifications = nil
				// The deferred gracefulShutdown takes over from here.
				return
			}
			nextPingAt := readAt.Add(PingDelay)
			w.logger.Debug("Received reader read, extending ping timer", zap.Time("nextPingAt", nextPingAt))
			w.sendNextPingAt(nextPingAt)

		case raw := <-w.channel:
			switch msg := raw.(type) {
			case SocketClosed:
				w.logger.Debug("Received close message, forwarding and shutting down", zap.Object("msg", msg))
				w.sendClose(msg)
				// sendClose has closed channel; the deferred gracefulShutdown takes over.
				return
			default:
				w.logger.Debug("Received message, forwarding", zap.Object("msg", msg))
				w.send(msg)
			}

		case <-w.timer.C:
			now := time.Now()
			if now.Before(w.nextPingAt) {
				// More reads arrived while the timer was armed and pushed
				// nextPingAt later; wait out the remaining time before pinging.
				w.timer.Reset(w.nextPingAt.Sub(now))
			} else {
				// The ping is due. Leave the timer idle (nextPingAt zero) until the
				// next read re-arms it through sendNextPingAt.
				w.sendPing()
				w.nextPingAt = time.Time{}
			}
		}
		w.logger.Debug("Awakening handled, resuming listening")
	}
}
|
|
|
|
// send marshals msg to protobuf and writes it to the client as a single
// websocket message, bounded by a WriteTimeLimit write deadline. Errors are
// logged rather than returned, so a failed send does not stop the writer loop.
//
// The message is written as a binary frame. Protobuf output is arbitrary bytes,
// but RFC 6455 requires text frames to carry valid UTF-8, and conforming peers
// (browsers included) fail the connection when they do not. Sending protobuf as
// websocket.TextMessage therefore breaks as soon as a payload contains a
// non-UTF-8 byte sequence.
func (w *writer) send(msg ServerCommand) {
	w.logger.Debug("Marshaling command as protobuf", zap.Object("msg", msg))
	marshaled, err := proto.Marshal(msg.ToServerPB())
	if err != nil {
		w.logger.Error("Error while marshaling to protobuf", zap.Error(err))
		return
	}

	writeDeadline := time.Now().Add(WriteTimeLimit)
	w.logger.Debug("Setting deadline to write command", zap.Time("writeDeadline", writeDeadline))
	if err := w.conn.SetWriteDeadline(writeDeadline); err != nil {
		// Best effort: log it and still attempt the write.
		w.logger.Error("Error while setting write deadline", zap.Time("writeDeadline", writeDeadline), zap.Object("msg", msg), zap.Error(err))
	}

	w.logger.Debug("Opening message writer to send command", zap.Object("msg", msg))
	wc, err := w.conn.NextWriter(websocket.BinaryMessage)
	if err != nil {
		w.logger.Error("Error while getting writer from connection", zap.Error(err))
		return
	}
	// Closing the message writer is what flushes the complete message to the
	// network, so it must happen even when Write fails.
	defer func(wc io.WriteCloser) {
		w.logger.Debug("Closing message writer to send command")
		if err := wc.Close(); err != nil {
			w.logger.Error("Error while closing writer to send command", zap.Error(err))
		}
		w.logger.Debug("Command sent")
	}(wc)

	if _, err := wc.Write(marshaled); err != nil {
		w.logger.Error("Error while writing marshaled protobuf to connection", zap.Error(err))
	}
}
|
|
|
|
// sendClose writes a websocket close frame built from msg.Code and msg.Text,
// allowing up to ControlTimeLimit for the write. The connection itself stays
// open. Before writing, it closes the writer's incoming channel and sets the
// field to nil. A failed write is logged as a warning and otherwise ignored.
func (w *writer) sendClose(msg SocketClosed) {
	w.logger.Debug("Shutting down the writer channel")
	close(w.channel)
	w.channel = nil

	w.logger.Debug("Writing close message")
	payload := websocket.FormatCloseMessage(msg.Code, msg.Text)
	deadline := time.Now().Add(ControlTimeLimit)
	if err := w.conn.WriteControl(websocket.CloseMessage, payload, deadline); err != nil {
		w.logger.Warn("Error while sending close", zap.Error(err))
	}
}
|
|
|
|
// sendPing writes a websocket ping control frame whose payload is PingData (the
// content itself carries no meaning), allowing up to ControlTimeLimit for the
// write. A failed write is logged and otherwise ignored.
func (w *writer) sendPing() {
	w.logger.Debug("Sending ping")
	deadline := time.Now().Add(ControlTimeLimit)
	if err := w.conn.WriteControl(websocket.PingMessage, []byte(PingData), deadline); err != nil {
		w.logger.Error("Error while sending ping", zap.Error(err))
	}
}
|
|
|
|
// gracefulShutdown runs after act returns. It waits for both of the writer's
// inputs to finish before the connection is closed:
//   - readNotifications must be closed, which means the reader has shut down;
//   - channel must deliver a SocketClosed, which means the main process has shut
//     down. That message is forwarded to the client as a close frame, and
//     sendClose closes channel and sets it to nil.
//
// While waiting, the ping timer is stopped, no pings are sent, and any other
// message arriving on channel is dropped. finalShutdown is deferred, so the
// connection is closed once this function returns.
func (w *writer) gracefulShutdown() {
	defer w.finalShutdown()

	// Disarm the ping timer. The non-blocking drain discards a fire that was
	// already delivered but never received (pre-Go 1.23 timer semantics) and
	// can never block. act defers this method before creating the timer, so
	// the timer may still be nil here.
	if w.timer != nil {
		if !w.timer.Stop() {
			select {
			case <-w.timer.C:
			default:
			}
		}
		w.timer = nil
	}
	w.nextPingAt = time.Time{}

	w.logger.Debug("Waiting for all channels to shut down")
	// A nil channel blocks forever in select, so each finished input simply
	// drops out of the wait.
	for w.channel != nil || w.readNotifications != nil {
		select {
		case _, open := <-w.readNotifications:
			if !open {
				w.logger.Debug("Received reader close while shutting down")
				w.readNotifications = nil
			}
		case raw := <-w.channel:
			switch msg := raw.(type) {
			case SocketClosed:
				w.logger.Debug("Received close message from channel while shutting down, forwarding", zap.Object("msg", msg))
				w.sendClose(msg)
			default:
				w.logger.Debug("Ignoring non-close message while shutting down", zap.Object("msg", msg))
			}
		}
	}
	w.logger.Debug("All channels closed, beginning final shutdown")
}
|
|
|
|
// finalShutdown closes the socket and finishes cleanup.
|
|
func (w *writer) finalShutdown() {
|
|
w.logger.Debug("Closing WebSocket connection")
|
|
err := w.conn.Close()
|
|
if err != nil {
|
|
w.logger.Error("Received an error while closing", zap.Error(err))
|
|
}
|
|
w.logger.Debug("Shut down")
|
|
}
|
|
|