diff --git a/internal/grid/connection.go b/internal/grid/connection.go index 73875c6b8..352eba4e2 100644 --- a/internal/grid/connection.go +++ b/internal/grid/connection.go @@ -707,26 +707,15 @@ func (c *Connection) connect() { if debugPrint { fmt.Println(c.Local, "Connected Waiting for Messages") } - c.updateState(StateConnected) - go c.handleMessages(c.ctx, conn) - // Monitor state changes and reconnect if needed. - c.connChange.L.Lock() - for { - newState := c.State() - if newState != StateConnected { - c.connChange.L.Unlock() - if newState == StateShutdown { - conn.Close() - return - } - if debugPrint { - fmt.Println(c.Local, "Disconnected") - } - // Reconnect - break - } - // Unlock and wait for state change. - c.connChange.Wait() + // Handle messages... + c.handleMessages(c.ctx, conn) + // Reconnect unless we are shutting down (debug only). + if c.State() == StateShutdown { + conn.Close() + return + } + if debugPrint { + fmt.Println(c.Local, "Disconnected. Attempting to reconnect.") } } } @@ -818,7 +807,7 @@ func (c *Connection) handleIncoming(ctx context.Context, conn net.Conn, req conn rid := uuid.UUID(req.ID) c.remoteID = &rid - c.updateState(StateConnected) + // Handle incoming messages until disconnect. c.handleMessages(ctx, conn) return nil } @@ -867,12 +856,36 @@ func (c *Connection) updateState(s State) { c.connChange.Broadcast() } +// monitorState will monitor the state of the connection and close the net.Conn if it changes. +func (c *Connection) monitorState(conn net.Conn, cancel context.CancelCauseFunc) { + c.connChange.L.Lock() + defer c.connChange.L.Unlock() + for { + newState := c.State() + if newState != StateConnected { + conn.Close() + cancel(ErrDisconnected) + return + } + // Unlock and wait for state change. + c.connChange.Wait() + } +} + // handleMessages will handle incoming messages on conn. // caller *must* hold reconnectMu. 
func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { + c.updateState(StateConnected) + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(ErrDisconnected) + + // This will ensure that if something asks to disconnect and we are blocked on reads/writes + // the connection will be closed and readers/writers will unblock. + go c.monitorState(conn, cancel) + c.handleMsgWg.Add(2) c.reconnectMu.Unlock() - ctx, cancel := context.WithCancelCause(ctx) + // Read goroutine go func() { defer func() { @@ -1034,7 +1047,6 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { lastPongTime := time.Unix(lastPong, 0) if d := time.Since(lastPongTime); d > connPingInterval*2 { logger.LogIf(ctx, fmt.Errorf("host %s last pong too old (%v); disconnecting", c.Remote, d.Round(time.Millisecond))) - cancel(ErrDisconnected) return } } @@ -1084,14 +1096,12 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { err := wsw.writeMessage(&buf, c.side, ws.OpBinary, toSend) if err != nil { logger.LogIf(ctx, fmt.Errorf("ws writeMessage: %w", err)) - cancel(ErrDisconnected) return } PutByteBuffer(toSend) _, err = buf.WriteTo(conn) if err != nil { logger.LogIf(ctx, fmt.Errorf("ws write: %w", err)) - cancel(ErrDisconnected) return } continue @@ -1109,7 +1119,6 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { toSend, err = m.MarshalMsg(toSend) if err != nil { logger.LogIf(ctx, fmt.Errorf("msg.MarshalMsg: %w", err)) - cancel(ErrDisconnected) return } // Append as byte slices. @@ -1126,14 +1135,12 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { err = wsw.writeMessage(&buf, c.side, ws.OpBinary, toSend) if err != nil { logger.LogIf(ctx, fmt.Errorf("ws writeMessage: %w", err)) - cancel(ErrDisconnected) return } - // Tosend is our local buffer, so we can reuse it. + // buf is our local buffer, so we can reuse it. 
_, err = buf.WriteTo(conn) if err != nil { logger.LogIf(ctx, fmt.Errorf("ws write: %w", err)) - cancel(ErrDisconnected) return } diff --git a/internal/grid/handlers.go b/internal/grid/handlers.go index 82f7efdb2..902f55312 100644 --- a/internal/grid/handlers.go +++ b/internal/grid/handlers.go @@ -691,6 +691,9 @@ func (h *StreamTypeHandler[Payload, Req, Resp]) Call(ctx context.Context, c Stre if h.InCapacity > 0 { reqT = make(chan Req) // Request handler + if stream.Requests == nil { + return nil, fmt.Errorf("internal error: stream request channel nil") + } go func() { defer close(stream.Requests) for req := range reqT { diff --git a/internal/grid/muxclient.go b/internal/grid/muxclient.go index a04122c30..3ea50ceae 100644 --- a/internal/grid/muxclient.go +++ b/internal/grid/muxclient.go @@ -533,7 +533,9 @@ func (m *muxClient) closeLocked() { if m.closed { return } - close(m.respWait) - m.respWait = nil + if m.respWait != nil { + close(m.respWait) + m.respWait = nil + } m.closed = true } diff --git a/internal/grid/muxserver.go b/internal/grid/muxserver.go index fd7096f41..907722462 100644 --- a/internal/grid/muxserver.go +++ b/internal/grid/muxserver.go @@ -21,6 +21,7 @@ import ( "context" "errors" "fmt" + "runtime/debug" "sync" "sync/atomic" "time" @@ -138,7 +139,8 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea } if r := recover(); r != nil { logger.LogIf(ctx, fmt.Errorf("grid handler (%v) panic: %v", msg.Handler, r)) - err := RemoteErr(fmt.Sprintf("panic: %v", r)) + debug.PrintStack() + err := RemoteErr(fmt.Sprintf("remote call panic: %v", r)) handlerErr = &err } if debugPrint { @@ -244,8 +246,10 @@ func (m *muxServer) message(msg message) { if len(msg.Payload) > 0 { logger.LogIf(m.ctx, fmt.Errorf("muxServer: EOF message with payload")) } - close(m.inbound) - m.inbound = nil + if m.inbound != nil { + close(m.inbound) + m.inbound = nil + } return }