# Source code for o2_sdk.websocket

"""WebSocket client for real-time O2 Exchange data streams.

Supports subscriptions for depth, orders, trades, balances, and nonce updates
with auto-reconnect and exponential backoff.

Lifecycle events (connected, disconnected, reconnecting, etc.) are available
via ``stream_lifecycle()`` so callers can re-sync state from REST after a
reconnect — critical for financial applications where missed messages can
cause incorrect position tracking.
"""

from __future__ import annotations

import asyncio
import contextlib
import enum
import json
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any

import websockets
from websockets.asyncio.client import ClientConnection

from .config import NetworkConfig
from .models import (
    BalanceUpdate,
    DepthUpdate,
    NonceUpdate,
    OrderUpdate,
    TradeUpdate,
)

logger = logging.getLogger("o2_sdk.websocket")


# ---------------------------------------------------------------------------
# Lifecycle events
# ---------------------------------------------------------------------------


class ConnectionState(enum.Enum):
    """Lifecycle states a WebSocket connection transitions through."""

    CONNECTED = "connected"
    DISCONNECTED = "disconnected"
    RECONNECTING = "reconnecting"
    RECONNECTED = "reconnected"
    # Terminal: max retries were exhausted or the client disconnected
    # explicitly. No further lifecycle events follow this one.
    CLOSED = "closed"


@dataclass(frozen=True)
class ConnectionEvent:
    """Immutable record of a WebSocket lifecycle transition.

    Attributes:
        state: The connection state being entered.
        attempt: Reconnect attempt counter; stays 0 outside of reconnects.
        message: Human-readable description of the transition.
    """

    state: ConnectionState
    attempt: int = 0
    message: str = ""


class O2WebSocket:
    """Async WebSocket client for O2 Exchange real-time data.

    Features:
    - Auto-reconnect with exponential backoff
    - Subscription tracking and automatic re-subscribe on reconnect
    - Per-subscriber message queues for safe concurrent access
    - Heartbeat ping/pong to detect silent disconnections
    - Lifecycle event channel for connection state awareness
    - Configurable max reconnect attempts

    Financial safety: after a reconnect the server replays the current
    order-book snapshot, but in-flight messages during the disconnect
    window are lost. Consume ``stream_lifecycle()`` and re-sync critical
    state (balances, open orders) from the REST API whenever a
    ``RECONNECTED`` event arrives.
    """

    def __init__(
        self,
        config: NetworkConfig,
        ping_interval: float = 30.0,
        pong_timeout: float = 60.0,
        max_reconnect_attempts: int = 10,
    ):
        self._config = config

        # Connection handle and run-state flags.
        self._ws: ClientConnection | None = None
        self._connected = False
        self._should_run = False
        self._close_event = asyncio.Event()

        # Subscriptions replayed by _do_connect() after every (re)connect.
        self._subscriptions: list[dict] = []
        # Per-subscriber fan-out: each stream_*() call registers its own
        # queue under an action key (e.g. "depth", "orders").
        self._subscriber_queues: dict[str, list[asyncio.Queue[Any]]] = {}

        # Background tasks: socket reader and heartbeat.
        self._listener_task: asyncio.Task | None = None
        self._ping_task: asyncio.Task | None = None

        # Heartbeat / reconnect tuning.
        self._ping_interval = ping_interval
        self._pong_timeout = pong_timeout
        self._reconnect_delay = 1.0
        self._max_reconnect_delay = 60.0
        self._max_reconnect_attempts = max_reconnect_attempts
        self._reconnect_attempts = 0

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------
[docs] async def connect(self) -> O2WebSocket: """Connect to the WebSocket endpoint.""" self._should_run = True self._close_event.clear() self._reconnect_attempts = 0 await self._do_connect() self._emit_lifecycle(ConnectionState.CONNECTED, message="Initial connection") self._ensure_background_tasks() return self
    async def _do_connect(self) -> None:
        """Open the raw WebSocket and replay all tracked subscriptions.

        Resets the backoff delay and attempt counter on success, so the
        next disconnection starts a fresh backoff sequence.
        """
        self._ws = await websockets.connect(self._config.ws_url)
        self._connected = True
        self._reconnect_delay = 1.0
        self._reconnect_attempts = 0
        logger.info("WebSocket connected to %s", self._config.ws_url)
        # Re-subscribe on reconnect
        for sub in self._subscriptions:
            await self._send(sub)

    def _ensure_background_tasks(self) -> None:
        """Start listener and ping tasks if they are not already running."""
        if self._listener_task is None or self._listener_task.done():
            self._listener_task = asyncio.create_task(self._listen())
        if self._ping_task is None or self._ping_task.done():
            self._ping_task = asyncio.create_task(self._ping_loop())

    async def _ping_loop(self) -> None:
        """Send periodic pings and trigger reconnect on pong timeout.

        Liveness is determined by the WebSocket protocol-level pong
        response, NOT by whether application data has arrived. This is
        critical because a subscription may legitimately receive no data
        for long periods (e.g. ``stream_depth`` at precision 0 on a quiet
        market). Reconnecting in that case would be incorrect — the
        connection is alive, there's just nothing to report.
        """
        try:
            while self._should_run and self._connected:
                await asyncio.sleep(self._ping_interval)
                # Re-check the flags after sleeping: disconnect() may have
                # run during the interval.
                if not self._should_run or not self._connected:
                    return
                if not self._ws:
                    return
                try:
                    # ws.ping() returns a Future that resolves when the
                    # protocol-level pong frame arrives. If the server is
                    # alive, this completes well within pong_timeout.
                    pong_waiter = await self._ws.ping()
                    await asyncio.wait_for(pong_waiter, timeout=self._pong_timeout)
                    logger.debug("Ping/pong OK")
                except asyncio.TimeoutError:
                    logger.warning(
                        "Pong timeout (%.1fs), triggering reconnect",
                        self._pong_timeout,
                    )
                    # Closing the socket makes the listener's recv() raise
                    # ConnectionClosed, which drives the reconnect path.
                    if self._ws:
                        await self._ws.close()
                    return
                except Exception:
                    # Connection error during ping — let listener handle it
                    return
        except asyncio.CancelledError:
            return

    async def _reconnect(self) -> None:
        """Reconnect with exponential backoff.

        CancelledError is intentionally NOT caught here so that
        ``disconnect()`` can interrupt a reconnect that's mid-backoff-sleep.
        """
        self._connected = False
        self._emit_lifecycle(ConnectionState.DISCONNECTED, message="Connection lost")
        while self._should_run:
            # Give up permanently once the attempt budget is spent
            # (a budget of 0 or less means "retry forever").
            if (
                self._max_reconnect_attempts > 0
                and self._reconnect_attempts >= self._max_reconnect_attempts
            ):
                logger.error(
                    "Max reconnect attempts (%d) reached, stopping",
                    self._max_reconnect_attempts,
                )
                self._should_run = False
                self._close_event.set()
                self._emit_lifecycle(
                    ConnectionState.CLOSED,
                    message=f"Max reconnect attempts ({self._max_reconnect_attempts}) exhausted",
                )
                # Wake every consumer generator so none hang forever.
                self._signal_all_queues(None)
                return
            self._reconnect_attempts += 1
            self._emit_lifecycle(
                ConnectionState.RECONNECTING,
                attempt=self._reconnect_attempts,
                message=f"Reconnecting in {self._reconnect_delay:.1f}s "
                f"(attempt {self._reconnect_attempts})",
            )
            logger.info(
                "Reconnecting in %.1fs (attempt %d)...",
                self._reconnect_delay,
                self._reconnect_attempts,
            )
            # CancelledError can interrupt this sleep — that's intentional.
            await asyncio.sleep(self._reconnect_delay)
            # Exponential backoff, capped at _max_reconnect_delay.
            self._reconnect_delay = min(self._reconnect_delay * 2, self._max_reconnect_delay)
            try:
                await self._do_connect()
                # Restart ping task (listener is the caller, so it's still alive)
                if self._ping_task is None or self._ping_task.done():
                    self._ping_task = asyncio.create_task(self._ping_loop())
                self._emit_lifecycle(
                    ConnectionState.RECONNECTED,
                    attempt=self._reconnect_attempts,
                    message="Reconnected — consumers should re-sync from REST",
                )
                return
            except asyncio.CancelledError:
                raise  # Let disconnect() kill the reconnect loop
            except Exception as e:
                logger.error("Reconnect failed: %s", e)
[docs] async def disconnect(self) -> None: """Disconnect from the WebSocket. Designed to complete in bounded time even when the connection is in a broken state. The steps are ordered so that consumer generators are unblocked as early as possible: 1. Signal shutdown (``_should_run = False``) 2. Signal all subscriber queues *first* — unblocks any consumer sitting in ``queue.get()`` before we attempt the (potentially slow) WS close handshake. 3. Cancel internal tasks with a timeout. 4. Close the WS connection with a timeout — a broken TCP connection can hang for minutes without one. """ logger.info("WebSocket disconnecting") self._should_run = False self._connected = False # Emit the terminal CLOSED event BEFORE setting _close_event, so # stream_lifecycle() consumers see it before _wait_for_message() # short-circuits them. The close event + sentinel then unblock all # generators (including lifecycle after it yields the CLOSED event). self._emit_lifecycle(ConnectionState.CLOSED, message="Disconnected by client") self._close_event.set() self._signal_all_queues(None) # Cancel internal tasks with a bounded wait. If _reconnect() is # mid-backoff the cancel interrupts the sleep; the 5s timeout is # a safety net in case CancelledError is unexpectedly suppressed. for task_name, task in [ ("ping", self._ping_task), ("listener", self._listener_task), ]: if task and not task.done(): task.cancel() try: await asyncio.wait_for(task, timeout=5.0) except (asyncio.CancelledError, asyncio.TimeoutError): logger.warning("WS %s task did not exit cleanly", task_name) # Close the underlying WS connection. A half-open TCP socket can # hang for the OS keepalive timeout (minutes) without this guard. if self._ws: try: await asyncio.wait_for(self._ws.close(), timeout=5.0) except (asyncio.TimeoutError, Exception) as e: logger.warning("WS close timed out or errored: %s", e) self._ws = None
# ------------------------------------------------------------------ # Internal message routing # ------------------------------------------------------------------ async def _wait_for_message(self, queue: asyncio.Queue[Any]) -> Any | None: """Wait for the next queue message OR the close event, whichever first. Returns the message, or ``None`` if the close event fired (shutdown) AND no message was already queued. When both are ready (e.g. the lifecycle CLOSED event was pushed just before the close event was set), the queued message wins so consumers don't miss terminal events. """ # Fast path: drain any already-queued message before checking close. if not queue.empty(): return queue.get_nowait() if self._close_event.is_set(): return None get_task = asyncio.ensure_future(queue.get()) close_task = asyncio.ensure_future(self._close_event.wait()) try: done, pending = await asyncio.wait( [get_task, close_task], return_when=asyncio.FIRST_COMPLETED, ) for p in pending: p.cancel() with contextlib.suppress(asyncio.CancelledError): await p # Prefer queued messages over close signal — ensures terminal # events (like CLOSED) are delivered before the generator exits. if get_task in done: return get_task.result() return None except asyncio.CancelledError: get_task.cancel() close_task.cancel() with contextlib.suppress(asyncio.CancelledError): _ = await get_task # ensure task is finalized with contextlib.suppress(asyncio.CancelledError): _ = await close_task # ensure task is finalized raise def _signal_all_queues(self, sentinel: object) -> None: """Push a sentinel value to every subscriber queue. Data-stream generators use ``_wait_for_message`` which races against ``_close_event``, so the sentinel is belt-and-suspenders for them. However ``stream_lifecycle`` uses raw ``queue.get()`` and **depends** on the sentinel to exit. If the queue is full we must make room — dropping one stale event is acceptable to guarantee the sentinel (and the CLOSED event before it) are delivered. 
""" for queues in self._subscriber_queues.values(): for q in queues: try: q.put_nowait(sentinel) except asyncio.QueueFull: # Drop oldest item to make room for the sentinel. with contextlib.suppress(asyncio.QueueEmpty): q.get_nowait() with contextlib.suppress(asyncio.QueueFull): q.put_nowait(sentinel) def _emit_lifecycle( self, state: ConnectionState, attempt: int = 0, message: str = "", ) -> None: """Push a lifecycle event to all lifecycle subscribers.""" event = ConnectionEvent(state=state, attempt=attempt, message=message) logger.info("WS lifecycle: %s%s", state.value, message) if "lifecycle" in self._subscriber_queues: for q in self._subscriber_queues["lifecycle"]: try: q.put_nowait(event) except asyncio.QueueFull: # Drain and retry — lifecycle events must not be lost while not q.empty(): try: q.get_nowait() except asyncio.QueueEmpty: break q.put_nowait(event) async def _send(self, message: dict) -> None: if self._ws: logger.debug("WS send: %s", message.get("action", message)) await self._ws.send(json.dumps(message)) async def _listen(self) -> None: """Read messages from the WebSocket and dispatch to subscriber queues. On connection loss this method handles reconnection internally and continues reading — it does NOT return after a reconnect, which was the root cause of the "orphaned queue" hang in previous versions. """ try: while self._should_run: if not self._ws or not self._connected: await asyncio.sleep(0.1) continue try: raw = await self._ws.recv() data = json.loads(raw) action = data.get("action", "") self._dispatch(action, data) except websockets.ConnectionClosed: logger.warning("WebSocket connection closed") if self._should_run: await self._reconnect() # After reconnect, loop back to recv() on the new # connection instead of returning. _do_connect() # already set self._ws to the new connection. 
continue return except asyncio.CancelledError: return except Exception as e: logger.error("WebSocket listener error: %s", e) if self._should_run: await self._reconnect() continue return except asyncio.CancelledError: return def _dispatch(self, action: str, data: dict) -> None: """Route messages to all subscriber queues for the matching action type.""" key = self._action_to_queue_key(action) if key and key in self._subscriber_queues: for q in self._subscriber_queues[key]: try: q.put_nowait(data) except asyncio.QueueFull: logger.warning("Subscriber queue full for %s, dropping message", key) logger.debug( "WS dispatched %s -> %d %s subscriber(s)", action, len(self._subscriber_queues[key]), key, ) elif key is None and action: logger.warning("WS unhandled action: %s", action) def _action_to_queue_key(self, action: str) -> str | None: if action in ("subscribe_depth", "subscribe_depth_update"): return "depth" elif action == "subscribe_orders": return "orders" elif action == "subscribe_trades": return "trades" elif action == "subscribe_balances": return "balances" elif action == "subscribe_nonce": return "nonce" return None def _register_queue(self, key: str) -> asyncio.Queue[Any]: """Create and register a new subscriber queue for the given action key.""" q: asyncio.Queue[Any] = asyncio.Queue(maxsize=1000) if key not in self._subscriber_queues: self._subscriber_queues[key] = [] self._subscriber_queues[key].append(q) return q def _unregister_queue(self, key: str, q: asyncio.Queue) -> None: """Remove a subscriber queue when the consumer exits.""" if key in self._subscriber_queues: with contextlib.suppress(ValueError): self._subscriber_queues[key].remove(q) if not self._subscriber_queues[key]: del self._subscriber_queues[key] def _add_subscription(self, sub: dict) -> None: """Add a subscription for reconnect tracking, deduplicating by content.""" if sub not in self._subscriptions: self._subscriptions.append(sub) # 
------------------------------------------------------------------ # Subscription methods # ------------------------------------------------------------------ async def stream_lifecycle(self) -> AsyncIterator[ConnectionEvent]: """Stream WebSocket connection lifecycle events. Yields ``ConnectionEvent`` objects whenever the connection state changes. Use this to detect reconnects and re-sync state from the REST API — messages received during the disconnect window are lost. Example:: async for event in ws.stream_lifecycle(): if event.state == ConnectionState.RECONNECTED: balances = await client.get_balances(account) orders = await client.get_orders(market, account) # ... rebuild local state ... elif event.state == ConnectionState.CLOSED: break # terminal, no more events """ queue = self._register_queue("lifecycle") try: while self._should_run: # Lifecycle uses raw queue.get() (not _wait_for_message) # because it must receive the terminal CLOSED event before # exiting. The CLOSED event is pushed to the queue before # _close_event is set, and _signal_all_queues(None) follows # as the final sentinel. msg = await queue.get() if msg is None: return yield msg finally: self._unregister_queue("lifecycle", queue)
[docs] async def stream_depth( self, market_id: str, wire_precision: str = "10" ) -> AsyncIterator[DepthUpdate]: """Subscribe to order book depth updates. Args: market_id: The market ID (hex string). wire_precision: Wire-format precision value (``10^level``). The caller (normally :meth:`O2Client.stream_depth`) converts the user-facing 1--18 index to the wire value. Default ``"10"`` (finest level). Note: Prefer :meth:`O2Client.stream_depth` which validates precision and resolves market pairs by name. """ sub = { "action": "subscribe_depth", "market_id": market_id, "precision": wire_precision, } self._add_subscription(sub) await self._send(sub) queue = self._register_queue("depth") try: while self._should_run: msg = await self._wait_for_message(queue) if msg is None: return if msg.get("market_id") == market_id: yield DepthUpdate.from_dict(msg) finally: self._unregister_queue("depth", queue)
[docs] async def stream_orders(self, identities: list[dict]) -> AsyncIterator[OrderUpdate]: """Subscribe to order updates for the given identities.""" sub = {"action": "subscribe_orders", "identities": identities} self._add_subscription(sub) await self._send(sub) queue = self._register_queue("orders") try: while self._should_run: msg = await self._wait_for_message(queue) if msg is None: return yield OrderUpdate.from_dict(msg) finally: self._unregister_queue("orders", queue)
[docs] async def stream_trades(self, market_id: str) -> AsyncIterator[TradeUpdate]: """Subscribe to trade updates for the given market.""" sub = {"action": "subscribe_trades", "market_id": market_id} self._add_subscription(sub) await self._send(sub) queue = self._register_queue("trades") try: while self._should_run: msg = await self._wait_for_message(queue) if msg is None: return if msg.get("market_id") == market_id: yield TradeUpdate.from_dict(msg) finally: self._unregister_queue("trades", queue)
[docs] async def stream_balances(self, identities: list[dict]) -> AsyncIterator[BalanceUpdate]: """Subscribe to balance updates for the given identities.""" sub = {"action": "subscribe_balances", "identities": identities} self._add_subscription(sub) await self._send(sub) queue = self._register_queue("balances") try: while self._should_run: msg = await self._wait_for_message(queue) if msg is None: return yield BalanceUpdate.from_dict(msg) finally: self._unregister_queue("balances", queue)
[docs] async def stream_nonce(self, identities: list[dict]) -> AsyncIterator[NonceUpdate]: """Subscribe to nonce updates for the given identities.""" sub = {"action": "subscribe_nonce", "identities": identities} self._add_subscription(sub) await self._send(sub) queue = self._register_queue("nonce") try: while self._should_run: msg = await self._wait_for_message(queue) if msg is None: return yield NonceUpdate.from_dict(msg) finally: self._unregister_queue("nonce", queue)
# ------------------------------------------------------------------ # Unsubscribe methods # ------------------------------------------------------------------
[docs] async def unsubscribe_depth(self, market_id: str) -> None: await self._send({"action": "unsubscribe_depth", "market_id": market_id}) self._subscriptions = [ s for s in self._subscriptions if not (s.get("action") == "subscribe_depth" and s.get("market_id") == market_id) ]
[docs] async def unsubscribe_orders(self) -> None: await self._send({"action": "unsubscribe_orders"}) self._subscriptions = [ s for s in self._subscriptions if s.get("action") != "subscribe_orders" ]
[docs] async def unsubscribe_trades(self, market_id: str) -> None: await self._send({"action": "unsubscribe_trades", "market_id": market_id}) self._subscriptions = [ s for s in self._subscriptions if not (s.get("action") == "subscribe_trades" and s.get("market_id") == market_id) ]
[docs] async def unsubscribe_balances(self, identities: list[dict]) -> None: await self._send({"action": "unsubscribe_balances", "identities": identities}) self._subscriptions = [ s for s in self._subscriptions if s.get("action") != "subscribe_balances" ]
[docs] async def unsubscribe_nonce(self, identities: list[dict]) -> None: await self._send({"action": "unsubscribe_nonce", "identities": identities}) self._subscriptions = [ s for s in self._subscriptions if s.get("action") != "subscribe_nonce" ]