1use futures_util::{SinkExt, StreamExt};
10use serde_json::json;
11use std::pin::Pin;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use std::task::{Context, Poll};
15use std::time::{Duration, Instant};
16use tokio::sync::broadcast;
17use tokio::sync::mpsc;
18use tokio::sync::Mutex;
19use tokio_stream::Stream;
20use tokio_tungstenite::tungstenite::Message as WsMsg;
21
22use crate::errors::O2Error;
23use crate::models::*;
24
25type WsSink = futures_util::stream::SplitSink<
26 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
27 WsMsg,
28>;
29
30type WsStream = futures_util::stream::SplitStream<
31 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
32>;
33
/// Reconnection and keep-alive settings for an `O2WebSocket`.
#[derive(Debug, Clone)]
pub struct WsConfig {
    /// Wait before the first reconnect attempt. The backoff restarts from this value whenever a
    /// reconnected socket is lost again.
    pub base_delay: Duration,
    /// Upper bound for the reconnect delay, which doubles after each failed attempt.
    pub max_delay: Duration,
    /// Consecutive reconnect attempts before giving up; `0` means retry forever.
    pub max_attempts: usize,
    /// How often a Ping frame is sent while connected.
    pub ping_interval: Duration,
    /// If no Pong has been seen for longer than this, the connection is closed and treated as
    /// lost.
    pub pong_timeout: Duration,
}

impl Default for WsConfig {
    /// 1 s base delay, 60 s max delay, 10 attempts, 30 s ping interval, 60 s pong timeout.
    fn default() -> Self {
        let secs = Duration::from_secs;
        Self {
            base_delay: secs(1),
            max_delay: secs(60),
            max_attempts: 10,
            ping_interval: secs(30),
            pong_timeout: secs(60),
        }
    }
}
60
61pub struct TypedStream<T> {
70 rx: mpsc::UnboundedReceiver<Result<T, O2Error>>,
71}
72
73impl<T> Stream for TypedStream<T> {
74 type Item = Result<T, O2Error>;
75
76 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
77 self.rx.poll_recv(cx)
78 }
79}
80
81#[derive(Debug, Clone)]
90pub struct DepthPrecision(String);
91
92impl DepthPrecision {
93 pub fn new(level: u64) -> Result<Self, O2Error> {
95 if !(1..=18).contains(&level) {
96 return Err(O2Error::InvalidRequest(format!(
97 "Invalid depth precision {level}. Must be an integer in range 1-18."
98 )));
99 }
100 Ok(Self(10u64.pow(level as u32).to_string()))
101 }
102
103 pub fn as_str(&self) -> &str {
105 &self.0
106 }
107}
108
/// Connection lifecycle notifications broadcast by an `O2WebSocket`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum WsLifecycleEvent {
    /// A reconnect attempt is about to start after waiting `delay`. `attempt` is 1-based.
    Reconnecting { attempt: usize, delay: Duration },
    /// The connection was re-established after `attempts` attempts.
    Reconnected { attempts: usize },
    /// The connection is gone. With `final_: true` the client will not reconnect on its own.
    Disconnected { reason: String, final_: bool },
}
117
118struct WsInner {
120 sink: Option<WsSink>,
121 subscriptions: Vec<serde_json::Value>,
122 depth_senders: Vec<mpsc::UnboundedSender<Result<DepthUpdate, O2Error>>>,
123 orders_senders: Vec<mpsc::UnboundedSender<Result<OrderUpdate, O2Error>>>,
124 trades_senders: Vec<mpsc::UnboundedSender<Result<TradeUpdate, O2Error>>>,
125 balances_senders: Vec<mpsc::UnboundedSender<Result<BalanceUpdate, O2Error>>>,
126 nonce_senders: Vec<mpsc::UnboundedSender<Result<NonceUpdate, O2Error>>>,
127}
128
129impl WsInner {
130 fn new() -> Self {
131 Self {
132 sink: None,
133 subscriptions: Vec::new(),
134 depth_senders: Vec::new(),
135 orders_senders: Vec::new(),
136 trades_senders: Vec::new(),
137 balances_senders: Vec::new(),
138 nonce_senders: Vec::new(),
139 }
140 }
141
142 fn prune_closed_senders(&mut self) {
144 self.depth_senders.retain(|s| !s.is_closed());
145 self.orders_senders.retain(|s| !s.is_closed());
146 self.trades_senders.retain(|s| !s.is_closed());
147 self.balances_senders.retain(|s| !s.is_closed());
148 self.nonce_senders.retain(|s| !s.is_closed());
149 }
150
151 fn close_all_senders(&mut self) {
153 self.depth_senders.clear();
154 self.orders_senders.clear();
155 self.trades_senders.clear();
156 self.balances_senders.clear();
157 self.nonce_senders.clear();
158 }
159
160 fn close_all_senders_with_error(&mut self, msg: &str) {
162 for tx in &self.depth_senders {
163 let _ = tx.send(Err(O2Error::WebSocketDisconnected(msg.to_string())));
164 }
165 for tx in &self.orders_senders {
166 let _ = tx.send(Err(O2Error::WebSocketDisconnected(msg.to_string())));
167 }
168 for tx in &self.trades_senders {
169 let _ = tx.send(Err(O2Error::WebSocketDisconnected(msg.to_string())));
170 }
171 for tx in &self.balances_senders {
172 let _ = tx.send(Err(O2Error::WebSocketDisconnected(msg.to_string())));
173 }
174 for tx in &self.nonce_senders {
175 let _ = tx.send(Err(O2Error::WebSocketDisconnected(msg.to_string())));
176 }
177 self.close_all_senders();
178 }
179}
180
181pub struct O2WebSocket {
186 url: String,
187 config: WsConfig,
188 inner: Arc<Mutex<WsInner>>,
189 connected: Arc<AtomicBool>,
190 should_run: Arc<AtomicBool>,
191 last_pong: Arc<Mutex<Instant>>,
192 lifecycle_tx: Arc<broadcast::Sender<WsLifecycleEvent>>,
193 reader_handle: Option<tokio::task::JoinHandle<()>>,
194 ping_handle: Option<tokio::task::JoinHandle<()>>,
195}
196
197impl O2WebSocket {
198 pub async fn connect(url: &str) -> Result<Self, O2Error> {
200 Self::connect_with_config(url, WsConfig::default()).await
201 }
202
203 pub async fn connect_with_config(url: &str, config: WsConfig) -> Result<Self, O2Error> {
205 let inner = Arc::new(Mutex::new(WsInner::new()));
206 let connected = Arc::new(AtomicBool::new(false));
207 let should_run = Arc::new(AtomicBool::new(true));
208 let last_pong = Arc::new(Mutex::new(Instant::now()));
209 let lifecycle_tx = Arc::new(broadcast::channel(64).0);
210
211 let mut ws = Self {
212 url: url.to_string(),
213 config,
214 inner,
215 connected,
216 should_run,
217 last_pong,
218 lifecycle_tx,
219 reader_handle: None,
220 ping_handle: None,
221 };
222
223 ws.do_connect().await?;
224 Ok(ws)
225 }
226
227 async fn do_connect(&mut self) -> Result<(), O2Error> {
228 let (ws_stream, _) = tokio_tungstenite::connect_async(&self.url).await?;
229 let (sink, stream) = ws_stream.split();
230
231 {
232 let mut guard = self.inner.lock().await;
233 guard.sink = Some(sink);
234 }
235
236 self.connected.store(true, Ordering::SeqCst);
237 *self.last_pong.lock().await = Instant::now();
238
239 {
241 let mut guard = self.inner.lock().await;
242 let subs = guard.subscriptions.clone();
243 if let Some(ref mut sink) = guard.sink {
244 for sub in &subs {
246 let text = serde_json::to_string(sub).unwrap_or_default();
247 let _ = sink.send(WsMsg::Text(text)).await;
248 }
249 }
250 }
251
252 let inner_clone = self.inner.clone();
254 let connected_clone = self.connected.clone();
255 let should_run_clone = self.should_run.clone();
256 let last_pong_clone = self.last_pong.clone();
257 let url_clone = self.url.clone();
258 let config_clone = self.config.clone();
259 let lifecycle_tx_clone = self.lifecycle_tx.clone();
260
261 let reader_handle = tokio::spawn(async move {
262 Self::read_loop(
263 stream,
264 inner_clone.clone(),
265 connected_clone.clone(),
266 should_run_clone.clone(),
267 last_pong_clone.clone(),
268 )
269 .await;
270
271 if should_run_clone.load(Ordering::SeqCst) {
273 connected_clone.store(false, Ordering::SeqCst);
274 Self::reconnect_loop(
275 &url_clone,
276 &config_clone,
277 inner_clone,
278 connected_clone,
279 should_run_clone,
280 last_pong_clone,
281 lifecycle_tx_clone,
282 )
283 .await;
284 }
285 });
286 self.reader_handle = Some(reader_handle);
287
288 let inner_ping = self.inner.clone();
290 let connected_ping = self.connected.clone();
291 let should_run_ping = self.should_run.clone();
292 let last_pong_ping = self.last_pong.clone();
293 let ping_interval = self.config.ping_interval;
294 let pong_timeout = self.config.pong_timeout;
295
296 let ping_handle = tokio::spawn(async move {
297 Self::ping_loop(
298 inner_ping,
299 connected_ping,
300 should_run_ping,
301 last_pong_ping,
302 ping_interval,
303 pong_timeout,
304 )
305 .await;
306 });
307 self.ping_handle = Some(ping_handle);
308
309 Ok(())
310 }
311
312 async fn read_loop(
313 mut stream: WsStream,
314 inner: Arc<Mutex<WsInner>>,
315 connected: Arc<AtomicBool>,
316 should_run: Arc<AtomicBool>,
317 last_pong: Arc<Mutex<Instant>>,
318 ) {
319 while should_run.load(Ordering::SeqCst) {
320 let msg = match stream.next().await {
321 Some(Ok(m)) => m,
322 Some(Err(_)) => break,
323 None => break,
324 };
325
326 match msg {
327 WsMsg::Text(text) => {
328 let text = text.to_string();
329 let parsed: serde_json::Value = match serde_json::from_str(&text) {
330 Ok(v) => v,
331 Err(_) => continue,
332 };
333
334 let action = parsed.get("action").and_then(|a| a.as_str()).unwrap_or("");
335
336 let mut guard = inner.lock().await;
337 guard.prune_closed_senders();
338
339 match action {
340 "subscribe_depth" | "subscribe_depth_update" => {
341 if let Ok(update) = serde_json::from_value::<DepthUpdate>(parsed) {
342 for tx in &guard.depth_senders {
343 let _ = tx.send(Ok(update.clone()));
344 }
345 }
346 }
347 "subscribe_orders" => {
348 if let Ok(update) = serde_json::from_value::<OrderUpdate>(parsed) {
349 for tx in &guard.orders_senders {
350 let _ = tx.send(Ok(update.clone()));
351 }
352 }
353 }
354 "subscribe_trades" => {
355 if let Ok(update) = serde_json::from_value::<TradeUpdate>(parsed) {
356 for tx in &guard.trades_senders {
357 let _ = tx.send(Ok(update.clone()));
358 }
359 }
360 }
361 "subscribe_balances" => {
362 if let Ok(update) = serde_json::from_value::<BalanceUpdate>(parsed) {
363 for tx in &guard.balances_senders {
364 let _ = tx.send(Ok(update.clone()));
365 }
366 }
367 }
368 "subscribe_nonce" => {
369 if let Ok(update) = serde_json::from_value::<NonceUpdate>(parsed) {
370 for tx in &guard.nonce_senders {
371 let _ = tx.send(Ok(update.clone()));
372 }
373 }
374 }
375 _ => {}
376 }
377 }
378 WsMsg::Pong(_) => {
379 *last_pong.lock().await = Instant::now();
380 }
381 WsMsg::Close(_) => {
382 connected.store(false, Ordering::SeqCst);
383 break;
384 }
385 WsMsg::Ping(data) => {
386 let mut guard = inner.lock().await;
388 if let Some(ref mut sink) = guard.sink {
389 let _ = sink.send(WsMsg::Pong(data)).await;
390 }
391 }
392 _ => {}
393 }
394 }
395 }
396
397 async fn ping_loop(
398 inner: Arc<Mutex<WsInner>>,
399 connected: Arc<AtomicBool>,
400 should_run: Arc<AtomicBool>,
401 last_pong: Arc<Mutex<Instant>>,
402 ping_interval: Duration,
403 pong_timeout: Duration,
404 ) {
405 let mut interval = tokio::time::interval(ping_interval);
406 interval.tick().await; while should_run.load(Ordering::SeqCst) {
409 interval.tick().await;
410
411 if !connected.load(Ordering::SeqCst) {
412 continue;
413 }
414
415 let last = *last_pong.lock().await;
417 if last.elapsed() > pong_timeout {
418 let mut guard = inner.lock().await;
420 if let Some(ref mut sink) = guard.sink {
421 let _ = sink.close().await;
422 }
423 connected.store(false, Ordering::SeqCst);
424 continue;
425 }
426
427 let mut guard = inner.lock().await;
429 if let Some(ref mut sink) = guard.sink {
430 let _ = sink.send(WsMsg::Ping(Vec::new())).await;
431 }
432 }
433 }
434
435 async fn reconnect_loop(
436 url: &str,
437 config: &WsConfig,
438 inner: Arc<Mutex<WsInner>>,
439 connected: Arc<AtomicBool>,
440 should_run: Arc<AtomicBool>,
441 last_pong: Arc<Mutex<Instant>>,
442 lifecycle_tx: Arc<broadcast::Sender<WsLifecycleEvent>>,
443 ) {
444 let mut delay = config.base_delay;
445 let mut attempts = 0;
446
447 while should_run.load(Ordering::SeqCst) {
448 if config.max_attempts > 0 && attempts >= config.max_attempts {
449 should_run.store(false, Ordering::SeqCst);
451 let mut guard = inner.lock().await;
452 let reason = "Connection lost after max retries".to_string();
453 guard.close_all_senders_with_error(&reason);
454 let _ = lifecycle_tx.send(WsLifecycleEvent::Disconnected {
455 reason,
456 final_: true,
457 });
458 return;
459 }
460
461 let _ = lifecycle_tx.send(WsLifecycleEvent::Reconnecting {
462 attempt: attempts + 1,
463 delay,
464 });
465 tokio::time::sleep(delay).await;
466 attempts += 1;
467
468 match tokio_tungstenite::connect_async(url).await {
469 Ok((ws_stream, _)) => {
470 let (sink, stream) = ws_stream.split();
471
472 {
473 let mut guard = inner.lock().await;
474 guard.sink = Some(sink);
475 }
476
477 connected.store(true, Ordering::SeqCst);
478 *last_pong.lock().await = Instant::now();
479
480 {
482 let mut guard = inner.lock().await;
483 let subs = guard.subscriptions.clone();
484 if let Some(ref mut sink) = guard.sink {
485 for sub in &subs {
486 let text = serde_json::to_string(sub).unwrap_or_default();
487 let _ = sink.send(WsMsg::Text(text)).await;
488 }
489 }
490 }
491 let _ = lifecycle_tx.send(WsLifecycleEvent::Reconnected { attempts });
492
493 Self::read_loop(
495 stream,
496 inner.clone(),
497 connected.clone(),
498 should_run.clone(),
499 last_pong.clone(),
500 )
501 .await;
502
503 if should_run.load(Ordering::SeqCst) {
505 connected.store(false, Ordering::SeqCst);
506 delay = config.base_delay;
507 attempts = 0;
508 continue;
509 }
510 return;
511 }
512 Err(_) => {
513 delay = (delay * 2).min(config.max_delay);
514 }
515 }
516 }
517 }
518
519 pub fn subscribe_lifecycle(&self) -> broadcast::Receiver<WsLifecycleEvent> {
521 self.lifecycle_tx.subscribe()
522 }
523
524 async fn send_json(&self, value: serde_json::Value) -> Result<(), O2Error> {
525 let text = serde_json::to_string(&value)?;
526 let mut guard = self.inner.lock().await;
527 if let Some(ref mut sink) = guard.sink {
528 sink.send(WsMsg::Text(text))
529 .await
530 .map_err(|e| O2Error::WebSocketError(e.to_string()))
531 } else {
532 Err(O2Error::WebSocketError("Not connected".into()))
533 }
534 }
535
536 fn add_subscription(inner: &mut WsInner, sub: serde_json::Value) {
537 if !inner.subscriptions.contains(&sub) {
538 inner.subscriptions.push(sub);
539 }
540 }
541
542 pub fn is_connected(&self) -> bool {
544 self.connected.load(Ordering::SeqCst)
545 }
546
547 pub async fn stream_depth(
557 &self,
558 market_id: &str,
559 precision: &DepthPrecision,
560 ) -> Result<TypedStream<DepthUpdate>, O2Error> {
561 let (tx, rx) = mpsc::unbounded_channel();
562 let sub = json!({
563 "action": "subscribe_depth",
564 "market_id": market_id,
565 "precision": precision.as_str()
566 });
567
568 {
569 let mut guard = self.inner.lock().await;
570 guard.depth_senders.push(tx);
571 Self::add_subscription(&mut guard, sub.clone());
572 }
573
574 self.send_json(sub).await?;
575 Ok(TypedStream { rx })
576 }
577
578 pub async fn stream_orders(
580 &self,
581 identities: &[Identity],
582 ) -> Result<TypedStream<OrderUpdate>, O2Error> {
583 let (tx, rx) = mpsc::unbounded_channel();
584 let sub = json!({
585 "action": "subscribe_orders",
586 "identities": identities
587 });
588
589 {
590 let mut guard = self.inner.lock().await;
591 guard.orders_senders.push(tx);
592 Self::add_subscription(&mut guard, sub.clone());
593 }
594
595 self.send_json(sub).await?;
596 Ok(TypedStream { rx })
597 }
598
599 pub async fn stream_trades(
601 &self,
602 market_id: &str,
603 ) -> Result<TypedStream<TradeUpdate>, O2Error> {
604 let (tx, rx) = mpsc::unbounded_channel();
605 let sub = json!({
606 "action": "subscribe_trades",
607 "market_id": market_id
608 });
609
610 {
611 let mut guard = self.inner.lock().await;
612 guard.trades_senders.push(tx);
613 Self::add_subscription(&mut guard, sub.clone());
614 }
615
616 self.send_json(sub).await?;
617 Ok(TypedStream { rx })
618 }
619
620 pub async fn stream_balances(
622 &self,
623 identities: &[Identity],
624 ) -> Result<TypedStream<BalanceUpdate>, O2Error> {
625 let (tx, rx) = mpsc::unbounded_channel();
626 let sub = json!({
627 "action": "subscribe_balances",
628 "identities": identities
629 });
630
631 {
632 let mut guard = self.inner.lock().await;
633 guard.balances_senders.push(tx);
634 Self::add_subscription(&mut guard, sub.clone());
635 }
636
637 self.send_json(sub).await?;
638 Ok(TypedStream { rx })
639 }
640
641 pub async fn stream_nonce(
643 &self,
644 identities: &[Identity],
645 ) -> Result<TypedStream<NonceUpdate>, O2Error> {
646 let (tx, rx) = mpsc::unbounded_channel();
647 let sub = json!({
648 "action": "subscribe_nonce",
649 "identities": identities
650 });
651
652 {
653 let mut guard = self.inner.lock().await;
654 guard.nonce_senders.push(tx);
655 Self::add_subscription(&mut guard, sub.clone());
656 }
657
658 self.send_json(sub).await?;
659 Ok(TypedStream { rx })
660 }
661
662 pub async fn unsubscribe_depth(&self, market_id: &str) -> Result<(), O2Error> {
664 self.send_json(json!({
665 "action": "unsubscribe_depth",
666 "market_id": market_id
667 }))
668 .await?;
669 let mut guard = self.inner.lock().await;
670 guard.subscriptions.retain(|s| {
671 !(s.get("action").and_then(|a| a.as_str()) == Some("subscribe_depth")
672 && s.get("market_id").and_then(|m| m.as_str()) == Some(market_id))
673 });
674 Ok(())
675 }
676
677 pub async fn unsubscribe_orders(&self) -> Result<(), O2Error> {
679 let unsub = json!({
680 "action": "unsubscribe_orders"
681 });
682 self.send_json(unsub).await?;
683 let mut guard = self.inner.lock().await;
684 guard
687 .subscriptions
688 .retain(|s| s.get("action").and_then(|a| a.as_str()) != Some("subscribe_orders"));
689 Ok(())
690 }
691
692 pub async fn unsubscribe_trades(&self, market_id: &str) -> Result<(), O2Error> {
694 self.send_json(json!({
695 "action": "unsubscribe_trades",
696 "market_id": market_id
697 }))
698 .await?;
699 let mut guard = self.inner.lock().await;
700 guard.subscriptions.retain(|s| {
701 !(s.get("action").and_then(|a| a.as_str()) == Some("subscribe_trades")
702 && s.get("market_id").and_then(|m| m.as_str()) == Some(market_id))
703 });
704 Ok(())
705 }
706
707 pub async fn unsubscribe_balances(&self, identities: &[Identity]) -> Result<(), O2Error> {
709 let unsub = json!({
710 "action": "unsubscribe_balances",
711 "identities": identities
712 });
713 self.send_json(unsub).await?;
714 let mut guard = self.inner.lock().await;
715 let exact_sub = json!({
716 "action": "subscribe_balances",
717 "identities": identities
718 });
719 guard.subscriptions.retain(|s| s != &exact_sub);
720 Ok(())
721 }
722
723 pub async fn unsubscribe_nonce(&self, identities: &[Identity]) -> Result<(), O2Error> {
725 let unsub = json!({
726 "action": "unsubscribe_nonce",
727 "identities": identities
728 });
729 self.send_json(unsub).await?;
730 let mut guard = self.inner.lock().await;
731 let exact_sub = json!({
732 "action": "subscribe_nonce",
733 "identities": identities
734 });
735 guard.subscriptions.retain(|s| s != &exact_sub);
736 Ok(())
737 }
738
739 pub fn is_terminated(&self) -> bool {
742 !self.should_run.load(Ordering::SeqCst)
743 }
744
745 pub async fn disconnect(&self) -> Result<(), O2Error> {
747 self.should_run.store(false, Ordering::SeqCst);
748 self.connected.store(false, Ordering::SeqCst);
749
750 let mut guard = self.inner.lock().await;
752 if let Some(ref mut sink) = guard.sink {
753 let _ = sink.send(WsMsg::Close(None)).await;
754 }
755
756 let _ = self.lifecycle_tx.send(WsLifecycleEvent::Disconnected {
759 reason: "Explicit disconnect".to_string(),
760 final_: true,
761 });
762
763 guard.close_all_senders();
765
766 Ok(())
767 }
768}
769
770impl Drop for O2WebSocket {
771 fn drop(&mut self) {
772 self.should_run.store(false, Ordering::SeqCst);
773 if let Some(h) = self.reader_handle.take() {
774 h.abort();
775 }
776 if let Some(h) = self.ping_handle.take() {
777 h.abort();
778 }
779 }
780}