diff --git a/.gitignore b/.gitignore
index d676b90..88509c8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,7 @@
tmp
.env
!.idea/httpRequests
+.venv
+abacus.exe
+*.bru
+bruno.json
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
index aae8b7c..4c6280e 100644
--- a/.idea/vcs.xml
+++ b/.idea/vcs.xml
@@ -9,4 +9,4 @@
-</project>
+</project>
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 8ee2a34..ad5a34e 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -18,5 +18,5 @@ CMD ["/abacus"]
HEALTHCHECK --interval=10s --timeout=3s --start-period=5s --retries=3 CMD wget -S -O - http://0.0.0.0:8080/healthcheck || exit 1
LABEL maintainer="Jason Cameron abacus@jasoncameron.dev"
-LABEL version="1.3.3"
+LABEL version="1.4.0"
LABEL description="This is a simple countAPI service written in Go."
diff --git a/docs/bugs/GHSA-vh64-54px-qgf8/results.md b/docs/bugs/GHSA-vh64-54px-qgf8/results.md
new file mode 100644
index 0000000..7a873bd
--- /dev/null
+++ b/docs/bugs/GHSA-vh64-54px-qgf8/results.md
@@ -0,0 +1,101 @@
+
+# After running the malicious script (request hangs ~21s, then the connection is reset)
+```bash
+❯ curl localhost:8080/stream/stream/eee/ -vvv
+23:01:20.486586 [0-x] == Info: [READ] client_reset, clear readers
+23:01:20.489178 [0-0] == Info: Host localhost:8080 was resolved.
+23:01:20.491295 [0-0] == Info: IPv6: ::1
+23:01:20.492584 [0-0] == Info: IPv4: 127.0.0.1
+23:01:20.494114 [0-0] == Info: [SETUP] added
+23:01:20.495946 [0-0] == Info: Trying [::1]:8080...
+23:01:20.498051 [0-0] == Info: Connected to localhost (::1) port 8080
+23:01:20.500158 [0-0] == Info: using HTTP/1.x
+23:01:20.502118 [0-0] => Send header, 96 bytes (0x60)
+0000: GET /stream/stream/eee/ HTTP/1.1
+0022: Host: localhost:8080
+0038: User-Agent: curl/8.10.1
+0051: Accept: */*
+005e:
+23:01:20.507492 [0-0] == Info: Request completely sent off
+23:01:41.832824 [0-0] == Info: Recv failure: Connection was reset
+23:01:41.835258 [0-0] == Info: [WRITE] cw-out done
+23:01:41.836819 [0-0] == Info: closing connection #0
+23:01:41.838722 [0-0] == Info: [SETUP] close
+23:01:41.840378 [0-0] == Info: [SETUP] destroy
+curl: (56) Recv failure: Connection was reset
+~ via v3.13.1 took 21s
+```
+
+# Before running the malicious script (SSE headers and events stream normally)
+```bash
+❯ curl localhost:8080/stream/stream/eee/ -vvv
+23:01:48.679494 [0-x] == Info: [READ] client_reset, clear readers
+23:01:48.682091 [0-0] == Info: Host localhost:8080 was resolved.
+23:01:48.684116 [0-0] == Info: IPv6: ::1
+23:01:48.685387 [0-0] == Info: IPv4: 127.0.0.1
+23:01:48.687056 [0-0] == Info: [SETUP] added
+23:01:48.688790 [0-0] == Info: Trying [::1]:8080...
+23:01:48.690916 [0-0] == Info: Connected to localhost (::1) port 8080
+23:01:48.692898 [0-0] == Info: using HTTP/1.x
+23:01:48.694577 [0-0] => Send header, 96 bytes (0x60)
+0000: GET /stream/stream/eee/ HTTP/1.1
+0022: Host: localhost:8080
+0038: User-Agent: curl/8.10.1
+0051: Accept: */*
+005e:
+23:01:48.699693 [0-0] == Info: Request completely sent off
+23:01:48.722188 [0-0] <= Recv header, 17 bytes (0x11)
+0000: HTTP/1.1 200 OK
+23:01:48.724437 [0-0] == Info: [WRITE] cw_out, wrote 17 header bytes -> 17
+23:01:48.726908 [0-0] == Info: [WRITE] download_write header(type=c, blen=17) -> 0
+23:01:48.729546 [0-0] == Info: [WRITE] client_write(type=c, len=17) -> 0
+23:01:48.731849 [0-0] <= Recv header, 25 bytes (0x19)
+0000: Cache-Control: no-cache
+23:01:48.734304 [0-0] == Info: [WRITE] header_collect pushed(type=1, len=25) -> 0
+23:01:48.736960 [0-0] == Info: [WRITE] cw_out, wrote 25 header bytes -> 25
+23:01:48.739246 [0-0] == Info: [WRITE] download_write header(type=4, blen=25) -> 0
+23:01:48.741791 [0-0] == Info: [WRITE] client_write(type=4, len=25) -> 0
+23:01:48.744142 [0-0] <= Recv header, 24 bytes (0x18)
+0000: Connection: keep-alive
+23:01:48.746735 [0-0] == Info: [WRITE] header_collect pushed(type=1, len=24) -> 0
+23:01:48.749113 [0-0] == Info: [WRITE] cw_out, wrote 24 header bytes -> 24
+23:01:48.751155 [0-0] == Info: [WRITE] download_write header(type=4, blen=24) -> 0
+23:01:48.753506 [0-0] == Info: [WRITE] client_write(type=4, len=24) -> 0
+23:01:48.755627 [0-0] <= Recv header, 33 bytes (0x21)
+0000: Content-Type: text/event-stream
+23:01:48.758653 [0-0] == Info: [WRITE] header_collect pushed(type=1, len=33) -> 0
+23:01:48.761138 [0-0] == Info: [WRITE] cw_out, wrote 33 header bytes -> 33
+23:01:48.763515 [0-0] == Info: [WRITE] download_write header(type=4, blen=33) -> 0
+23:01:48.766179 [0-0] == Info: [WRITE] client_write(type=4, len=33) -> 0
+23:01:48.768501 [0-0] <= Recv header, 37 bytes (0x25)
+0000: Date: Sun, 02 Mar 2025 04:01:48 GMT
+23:01:48.771446 [0-0] == Info: [WRITE] header_collect pushed(type=1, len=37) -> 0
+23:01:48.773906 [0-0] == Info: [WRITE] cw_out, wrote 37 header bytes -> 37
+23:01:48.776037 [0-0] == Info: [WRITE] download_write header(type=4, blen=37) -> 0
+23:01:48.778591 [0-0] == Info: [WRITE] client_write(type=4, len=37) -> 0
+23:01:48.780857 [0-0] == Info: [WRITE] looking for transfer decoder: chunked
+23:01:48.783072 [0-0] == Info: [WRITE] added transfer decoder chunked -> 0
+23:01:48.785259 [0-0] <= Recv header, 28 bytes (0x1c)
+0000: Transfer-Encoding: chunked
+23:01:48.788082 [0-0] == Info: [WRITE] header_collect pushed(type=1, len=28) -> 0
+23:01:48.790781 [0-0] == Info: [WRITE] cw_out, wrote 28 header bytes -> 28
+23:01:48.792926 [0-0] == Info: [WRITE] download_write header(type=4, blen=28) -> 0
+23:01:48.795653 [0-0] == Info: [WRITE] client_write(type=4, len=28) -> 0
+23:01:48.797787 [0-0] <= Recv header, 2 bytes (0x2)
+0000:
+23:01:48.799552 [0-0] == Info: [WRITE] header_collect pushed(type=1, len=2) -> 0
+23:01:48.802100 [0-0] == Info: [WRITE] cw_out, wrote 2 header bytes -> 2
+23:01:48.804303 [0-0] == Info: [WRITE] download_write header(type=4, blen=2) -> 0
+23:01:48.806762 [0-0] == Info: [WRITE] client_write(type=4, len=2) -> 0
+23:01:48.808821 [0-0] <= Recv data, 26 bytes (0x1a)
+0000: 14
+0004: data: {"value":31}..
+23:01:48.811317 [0-0] == Info: [WRITE] http_chunked, chunk start of 20 bytes
+data: {"value":31}
+
+23:01:48.813647 [0-0] == Info: [WRITE] cw_out, wrote 20 body bytes -> 20
+23:01:48.815844 [0-0] == Info: [WRITE] download_write body(type=1, blen=20) -> 0
+23:01:48.818394 [0-0] == Info: [WRITE] http_chunked, write 20 body bytes, 0 bytes in chunk remain
+23:01:48.821546 [0-0] == Info: [WRITE] client_write(type=1, len=26) -> 0
+23:01:48.823790 [0-0] == Info: [WRITE] xfer_write_resp(len=192, eos=0) -> 0
+```
diff --git a/docs/bugs/GHSA-vh64-54px-qgf8/test.py b/docs/bugs/GHSA-vh64-54px-qgf8/test.py
new file mode 100644
index 0000000..ceac67d
--- /dev/null
+++ b/docs/bugs/GHSA-vh64-54px-qgf8/test.py
@@ -0,0 +1,498 @@
+#!/usr/bin/env python3
+"""
+Enhanced test script for Abacus SSE goroutine leak.
+
+This script creates three types of connections:
+1. Quick connections that immediately disconnect
+2. Lingering connections that stay open for a longer period
+3. Zombie connections that are left open until the end of the test
+
+This mixed approach better simulates real-world client behavior and tests
+the server's ability to clean up all types of disconnected clients.
+
+Author: JasonLovesDoggo (2025-03-02)
+"""
+
+import argparse
+import concurrent.futures
+import json
+import os
+import platform
+import signal
+import sys
+import threading
+import time
+from datetime import datetime
+try:
+ import psutil
+ PSUTIL_AVAILABLE = True
+except ImportError:
+ PSUTIL_AVAILABLE = False
+
+import requests
+import colorama
+
+# Initialize colorama for Windows terminal support
+colorama.init()
+
+class AbacusLeakTester:
+ def __init__(self, server_url, process_name='abacus',
+ quick_connections=50, lingering_connections=20, zombie_connections=10,
+ num_batches=5, delay_between_batches=5, linger_time=10,
+ timeout=3, max_workers=10):
+ self.server_url = server_url.rstrip('/')
+ self.endpoint = "/stream/test/leak_test"
+ self.process_name = process_name
+
+ # Connection types and counts
+ self.quick_connections = quick_connections
+ self.lingering_connections = lingering_connections
+ self.zombie_connections = zombie_connections
+
+ self.num_batches = num_batches
+ self.delay_between_batches = delay_between_batches
+ self.linger_time = linger_time
+ self.timeout = timeout
+ self.max_workers = max_workers
+ self.process = None
+ self.initial_memory = None
+
+ # Tracking the zombie connections
+ self.zombie_sessions = []
+ self.zombie_responses = []
+
+ # Results tracking
+ self.total_successful = 0
+ self.total_failed = 0
+ self.memory_readings = []
+
+ # Windows-specific process name adjustments
+ if platform.system() == "Windows":
+ if not self.process_name.lower().endswith('.exe'):
+ self.process_name += '.exe'
+
+ def find_process(self):
+ """Find the server process if running locally."""
+ if not PSUTIL_AVAILABLE:
+ print("psutil module not available. Memory tracking disabled.")
+ return None
+
+ for proc in psutil.process_iter(['pid', 'name']):
+ try:
+ if self.process_name.lower() in proc.info['name'].lower():
+ return proc
+ except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
+ pass
+
+ return None
+
+ def get_memory_usage(self):
+ """Get current memory usage of the process in MB."""
+ if not self.process:
+ return None
+
+ try:
+ # Wait a bit for any memory operations to settle
+ time.sleep(1)
+ # psutil reads fresh values on every call, so a single memory_info() read is enough
+ memory = self.process.memory_info().rss / (1024 * 1024)
+ self.memory_readings.append(memory)
+ return memory
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ print("Process no longer accessible or has terminated.")
+ self.process = None
+ return None
+
+ def create_test_counter(self):
+ """Create a test counter if it doesn't exist."""
+ try:
+ response = requests.post(f"{self.server_url}/create/test/leak_test")
+ if response.status_code == 201:
+ print("\033[92mCreated test counter\033[0m")
+ else:
+ print(f"\033[93mCounter creation response: {response.status_code}\033[0m")
+ except Exception as e:
+ print(f"\033[93mCounter may already exist, continuing... ({e})\033[0m")
+
+ def make_quick_connection(self, connection_id, batch_id):
+ """Make a connection that immediately disconnects."""
+ session = requests.Session()
+ try:
+ # Start a streaming request
+ headers = {"Accept": "text/event-stream"}
+ response = session.get(
+ f"{self.server_url}{self.endpoint}",
+ headers=headers,
+ stream=True,
+ timeout=self.timeout
+ )
+
+ # Just read a tiny bit of data
+ try:
+ next(response.iter_content(chunk_size=64))
+ except (StopIteration, requests.exceptions.ChunkedEncodingError,
+ requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout):
+ pass
+
+ # Abruptly close the connection
+ response.close()
+
+ # Increment the counter
+ hit_response = session.get(f"{self.server_url}/hit/test/leak_test")
+ hit_response.raise_for_status()
+
+ return True
+ except Exception as e:
+ if "Read timed out" not in str(e): # Ignore expected timeouts
+ print(f"Quick connection {connection_id} in batch {batch_id} failed: {str(e)}")
+ return False
+ finally:
+ session.close()
+
+ def make_lingering_connection(self, connection_id, batch_id):
+ """Make a connection that stays open for a while before closing."""
+ session = requests.Session()
+ try:
+ # Start a streaming request
+ headers = {"Accept": "text/event-stream"}
+ response = session.get(
+ f"{self.server_url}{self.endpoint}",
+ headers=headers,
+ stream=True,
+ timeout=self.timeout + self.linger_time
+ )
+
+ # Read a bit of data
+ try:
+ for _ in range(2): # Read a couple of chunks to establish connection
+ next(response.iter_content(chunk_size=128))
+ except (StopIteration, requests.exceptions.ChunkedEncodingError):
+ pass
+
+ print(f" Lingering connection {connection_id} established, will stay open for {self.linger_time}s")
+
+ # Keep connection open for a while
+ time.sleep(self.linger_time)
+
+ # Properly close the connection
+ response.close()
+
+ # Increment the counter
+ hit_response = session.get(f"{self.server_url}/hit/test/leak_test")
+ hit_response.raise_for_status()
+
+ print(f" Lingering connection {connection_id} properly closed after {self.linger_time}s")
+ return True
+ except Exception as e:
+ print(f"Lingering connection {connection_id} in batch {batch_id} failed: {str(e)}")
+ return False
+ finally:
+ session.close()
+
+ def make_zombie_connection(self, connection_id, batch_id):
+ """Make a connection that is never explicitly closed (until cleanup)."""
+ try:
+ # Create a persistent session
+ session = requests.Session()
+ self.zombie_sessions.append(session)
+
+ # Start a streaming request
+ headers = {"Accept": "text/event-stream"}
+ response = session.get(
+ f"{self.server_url}{self.endpoint}",
+ headers=headers,
+ stream=True,
+ timeout=60 # Long timeout
+ )
+ self.zombie_responses.append(response)
+
+ # Read just a bit to establish the connection
+ try:
+ next(response.iter_content(chunk_size=64))
+ except (StopIteration, requests.exceptions.ChunkedEncodingError):
+ return False
+
+ print(f" Zombie connection {connection_id} established (will remain open)")
+
+ # Increment the counter
+ hit_response = requests.get(f"{self.server_url}/hit/test/leak_test")
+ hit_response.raise_for_status()
+
+ return True
+ except Exception as e:
+ print(f"Zombie connection {connection_id} in batch {batch_id} failed: {str(e)}")
+ return False
+
+ def cleanup_zombie_connections(self):
+ """Clean up any zombie connections at the end of the test."""
+ print("\n\033[93mCleaning up zombie connections...\033[0m")
+ for i, response in enumerate(self.zombie_responses):
+ try:
+ response.close()
+ print(f" Closed zombie connection {i+1}")
+ except Exception:
+ pass
+
+ for i, session in enumerate(self.zombie_sessions):
+ try:
+ session.close()
+ print(f" Closed zombie session {i+1}")
+ except Exception:
+ pass
+
+ # Clear the lists
+ self.zombie_responses = []
+ self.zombie_sessions = []
+
+ def run_batch(self, batch_id):
+ """Run a batch with different connection types."""
+ print(f"\033[95mStarting batch {batch_id} of {self.num_batches}\033[0m")
+
+ batch_start = time.time()
+ batch_successful = 0
+ batch_failed = 0
+
+ # 1. Quick connections (in parallel)
+ if self.quick_connections > 0:
+ print(f" Creating {self.quick_connections} quick connections...")
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+ futures = []
+ for i in range(1, self.quick_connections + 1):
+ futures.append(executor.submit(self.make_quick_connection, i, batch_id))
+ if i % 20 == 0:
+ time.sleep(0.2) # Stagger connections
+
+ for future in concurrent.futures.as_completed(futures):
+ if future.result():
+ batch_successful += 1
+ else:
+ batch_failed += 1
+
+ # 2. Lingering connections (in parallel but with careful thread management)
+ if self.lingering_connections > 0:
+ print(f" Creating {self.lingering_connections} lingering connections...")
+ # Use a smaller thread pool to avoid overwhelming resources
+ max_lingering_threads = min(5, self.max_workers)
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_lingering_threads) as executor:
+ futures = []
+ for i in range(1, self.lingering_connections + 1):
+ futures.append(executor.submit(self.make_lingering_connection, i, batch_id))
+ time.sleep(0.5) # Stagger lingering connections more
+
+ for future in concurrent.futures.as_completed(futures):
+ if future.result():
+ batch_successful += 1
+ else:
+ batch_failed += 1
+
+ # 3. Zombie connections (create but don't close)
+ if batch_id == 1 and self.zombie_connections > 0: # Only create zombies in first batch
+ print(f" Creating {self.zombie_connections} zombie connections...")
+ for i in range(1, self.zombie_connections + 1):
+ if self.make_zombie_connection(i, batch_id):
+ batch_successful += 1
+ else:
+ batch_failed += 1
+ time.sleep(0.5) # Delay between zombie connections
+
+ self.total_successful += batch_successful
+ self.total_failed += batch_failed
+
+ batch_duration = time.time() - batch_start
+ print(f"\033[92mBatch {batch_id} completed: {batch_successful} successful, "
+ f"{batch_failed} failed. Duration: {batch_duration:.2f}s\033[0m")
+
+ return batch_successful, batch_failed
+
+ def check_memory(self):
+ """Check and report on memory usage."""
+ if not self.process:
+ return True
+
+ current_memory = self.get_memory_usage()
+ if current_memory is None:
+ print("\033[91mProcess no longer found. Server may have crashed!\033[0m")
+ return False
+
+ memory_diff = current_memory - self.initial_memory
+ color = "\033[96m" # Cyan
+ if memory_diff > 15:
+ color = "\033[91m" # Red
+ elif memory_diff > 5:
+ color = "\033[93m" # Yellow
+
+ sign = "+" if memory_diff >= 0 else ""
+ print(f"{color}Current memory: {current_memory:.2f} MB ({sign}{memory_diff:.2f} MB)\033[0m")
+ return True
+
+ def print_final_stats(self):
+ """Print final test statistics."""
+ final_memory = None
+ if self.process:
+ final_memory = self.get_memory_usage()
+
+ print("\n\033[95mTest completed!\033[0m")
+ total_connections = (
+ (self.quick_connections * self.num_batches) +
+ (self.lingering_connections * self.num_batches) +
+ self.zombie_connections
+ )
+ print(f"\033[97mTotal connections attempted: {total_connections}\033[0m")
+ print(f" Quick connections: {self.quick_connections * self.num_batches}")
+ print(f" Lingering connections: {self.lingering_connections * self.num_batches}")
+ print(f" Zombie connections: {self.zombie_connections}")
+
+ print(f"\033[92mSuccessful: {self.total_successful}\033[0m")
+ color = "\033[92m" if self.total_failed == 0 else "\033[91m"
+ print(f"{color}Failed: {self.total_failed}\033[0m")
+
+ if self.initial_memory is not None and len(self.memory_readings) > 1:
+ print(f"\n\033[96mMemory Analysis:\033[0m")
+ print(f"\033[97mInitial memory: {self.initial_memory:.2f} MB\033[0m")
+ print(f"\033[97mFinal memory: {final_memory:.2f} MB\033[0m")
+
+ memory_growth = final_memory - self.initial_memory
+ growth_percent = (memory_growth / self.initial_memory) * 100
+
+ sign = "+" if memory_growth >= 0 else ""
+ color = "\033[92m" # Green
+ if growth_percent > 20:
+ color = "\033[91m" # Red
+ elif growth_percent > 10:
+ color = "\033[93m" # Yellow
+
+ print(f"{color}Growth: {sign}{memory_growth:.2f} MB ({sign}{growth_percent:.2f}%)\033[0m")
+
+ # Check for consistent growth pattern
+ if len(self.memory_readings) >= 3:
+ print("\n\033[96mMemory Growth Pattern:\033[0m")
+ for i in range(1, len(self.memory_readings)):
+ diff = self.memory_readings[i] - self.memory_readings[i-1]
+ print(f" Batch {i}: {self.memory_readings[i]:.2f} MB ({'+' if diff >= 0 else ''}{diff:.2f} MB)")
+
+ # Check for leak indicators
+ consistent_growth = True
+ baseline_diff = self.memory_readings[1] - self.memory_readings[0]
+ for i in range(2, len(self.memory_readings)):
+ diff = self.memory_readings[i] - self.memory_readings[i-1]
+ # If growth is inconsistent (allowing for some variance)
+ if diff < 0 or abs(diff - baseline_diff) > max(baseline_diff * 0.5, 1.0):
+ consistent_growth = False
+
+ if memory_growth > 10 and consistent_growth:
+ print(f"\033[91mConsistent memory growth detected across batches!")
+ print(f"This strongly indicates a memory/goroutine leak.\033[0m")
+ elif memory_growth > 10:
+ print(f"\033[93mSignificant memory growth detected but pattern is inconsistent.")
+ print(f"This may indicate a partial leak or normal memory variation.\033[0m")
+ elif memory_growth > 5:
+ print(f"\033[93mModerate memory growth detected. May be normal variation.\033[0m")
+ else:
+ print(f"\033[92mMemory usage appears stable. No obvious leak detected.\033[0m")
+
+ def get_final_counter(self):
+ """Get the final counter value."""
+ try:
+ response = requests.get(f"{self.server_url}/get/test/leak_test")
+ counter_value = response.json().get('value', 'unknown')
+ print(f"\n\033[96mCounter value after test: {counter_value}\033[0m")
+ except Exception as e:
+ print(f"\n\033[91mCould not get final counter value: {e}\033[0m")
+
+ def run_test(self):
+ """Run the complete test."""
+ print(f"Testing Abacus SSE endpoint at {self.server_url}")
+ print(f"Running on {platform.system()} {platform.release()}")
+ print(f"Connection configuration:")
+ print(f" - Quick connections: {self.quick_connections}/batch")
+ print(f" - Lingering connections: {self.lingering_connections}/batch (stay open for {self.linger_time}s)")
+ print(f" - Zombie connections: {self.zombie_connections} (left open until end)")
+ print(f" - Total batches: {self.num_batches}")
+
+ # Find process for memory tracking
+ self.process = self.find_process()
+ if self.process:
+ self.initial_memory = self.get_memory_usage()
+ print(f"\033[96mInitial memory usage: {self.initial_memory:.2f} MB\033[0m")
+ else:
+ print("\033[93mCould not find local process. Memory tracking disabled.\033[0m")
+
+ # Create the counter
+ self.create_test_counter()
+
+ try:
+ # Run test batches
+ for batch in range(1, self.num_batches + 1):
+ self.run_batch(batch)
+
+ # Check memory
+ if self.process and not self.check_memory():
+ break
+
+ # Delay between batches
+ if batch < self.num_batches:
+ print(f"\033[90mWaiting {self.delay_between_batches} seconds before next batch...\033[0m")
+ time.sleep(self.delay_between_batches)
+
+ # After all batches, wait a bit longer to see if memory stabilizes
+ print("\n\033[93mWaiting 10 seconds for memory to stabilize...\033[0m")
+ time.sleep(10)
+
+ # Final memory check
+ if self.process:
+ self.check_memory()
+
+ # Print statistics
+ self.print_final_stats()
+ self.get_final_counter()
+
+ finally:
+ # Always clean up zombie connections
+ self.cleanup_zombie_connections()
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Test for goroutine leaks in Abacus SSE implementation')
+ parser.add_argument('--url', default='http://localhost:8080', help='Abacus server URL')
+ parser.add_argument('--process', default='abacus', help='Process name for memory tracking')
+ parser.add_argument('--quick', type=int, default=50, help='Quick connections per batch')
+ parser.add_argument('--lingering', type=int, default=20, help='Lingering connections per batch')
+ parser.add_argument('--zombie', type=int, default=10, help='Total zombie connections')
+ parser.add_argument('--batches', type=int, default=5, help='Number of batches')
+ parser.add_argument('--delay', type=int, default=5, help='Delay between batches (seconds)')
+ parser.add_argument('--linger', type=int, default=10, help='How long lingering connections stay open (seconds)')
+ parser.add_argument('--workers', type=int, default=10, help='Max concurrent connections')
+
+ args = parser.parse_args()
+
+ # Platform-specific adjustments
+ if platform.system() == "Windows":
+ if args.workers > 10:
+ args.workers = 10
+ if args.quick > 30:
+ args.quick = 30
+
+ tester = AbacusLeakTester(
+ server_url=args.url,
+ process_name=args.process,
+ quick_connections=args.quick,
+ lingering_connections=args.lingering,
+ zombie_connections=args.zombie,
+ num_batches=args.batches,
+ delay_between_batches=args.delay,
+ linger_time=args.linger,
+ max_workers=args.workers
+ )
+
+ try:
+ tester.run_test()
+ except KeyboardInterrupt:
+ print("\n\033[93mTest interrupted by user.\033[0m")
+ # Clean up even on keyboard interrupt
+ tester.cleanup_zombie_connections()
+ finally:
+ colorama.deinit()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/docs/index.html b/docs/index.html
index 9c2eadc..b0da263 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -569,7 +569,7 @@
/stats
"expired_keys__since_restart": "130", // number of keys expired since db's last restart
"key_misses__since_restart": "205", // number of keys not found since db's last restart
"total_keys": 87904, // total number of keys created
- "version": "1.3.3", // Abacus's version
+ "version": "1.4.0", // Abacus's version
"shard": "boujee-coorgi", // Handler shard
"uptime": "1h23m45s" // shard uptime
diff --git a/main.go b/main.go
index 0c351d8..608aa0a 100644
--- a/main.go
+++ b/main.go
@@ -31,7 +31,7 @@ import (
const (
DocsUrl string = "https://jasoncameron.dev/abacus/"
- Version string = "1.3.3"
+ Version string = "1.4.0"
)
var (
diff --git a/middleware/sse.go b/middleware/sse.go
index 4c2c049..b0334fa 100644
--- a/middleware/sse.go
+++ b/middleware/sse.go
@@ -1,6 +1,8 @@
package middleware
-import "github.com/gin-gonic/gin"
+import (
+ "github.com/gin-gonic/gin"
+)
func SSEMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
@@ -8,6 +10,7 @@ func SSEMiddleware() gin.HandlerFunc {
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
+ c.Writer.Header().Set("X-Accel-Buffering", "no") // Disable proxy buffering
c.Next()
}
}
diff --git a/routes.go b/routes.go
index 6d2fb8a..a7a9516 100644
--- a/routes.go
+++ b/routes.go
@@ -3,6 +3,7 @@ package main
import (
"context"
"errors"
+ "fmt"
"io"
"log"
"math"
@@ -38,95 +39,95 @@ func StreamValueView(c *gin.Context) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
- c.Header("Access-Control-Allow-Origin", "*")
- // Initialize client channel with buffer to prevent deadlock
- clientChan := make(chan int, 5) // Buffer a few values to prevent blocking
+ // Initialize client channel with a buffer to prevent blocking
+ clientChan := make(chan int, 10)
- // Create a cancellable context for this client
- ctx, cancel := context.WithCancel(c.Request.Context())
- defer cancel() // Ensure context is always canceled
+ // Create a context that's canceled when the client disconnects
+ ctx := c.Request.Context()
- // Client registration
+ // Add this client to the event server for this specific key
utils.ValueEventServer.NewClients <- utils.KeyClientPair{
Key: dbKey,
Client: clientChan,
}
- // Use a proper mutex-protected flag to track cleanup state
- var cleanupOnce sync.Once
- cleanup := func() {
- cleanupOnce.Do(func() {
- log.Printf("Cleaning up client for key %s", dbKey)
+ // Track if cleanup has been done
+ var cleanupDone bool
+ var cleanupMutex sync.Mutex
- // Signal that this client is closed
+ // Ensure client is always removed when handler exits
+ defer func() {
+ cleanupMutex.Lock()
+ if !cleanupDone {
+ cleanupDone = true
+ cleanupMutex.Unlock()
+
+ // Signal the event server to remove this client
select {
- case utils.ValueEventServer.ClosedClients <- utils.KeyClientPair{
- Key: dbKey,
- Client: clientChan,
- }:
- // Successfully sent close signal
+ case utils.ValueEventServer.ClosedClients <- utils.KeyClientPair{Key: dbKey, Client: clientChan}:
+ // Successfully sent cleanup signal
case <-time.After(500 * time.Millisecond):
- log.Printf("Warning: Timed out sending client closure signal for %s", dbKey)
+ // Timed out waiting to send cleanup signal
+ log.Printf("Warning: Timed out sending cleanup signal for %s", dbKey)
}
+ } else {
+ cleanupMutex.Unlock()
+ }
+ }()
- // Use a separate goroutine to drain any remaining messages
- // This prevents blocking the cleanup function
- go func() {
- timeout := time.NewTimer(1 * time.Second)
- defer timeout.Stop()
-
- // Drain any pending messages
- for {
- select {
- case _, ok := <-clientChan:
- if !ok {
- return // Channel already closed
- }
- // Discard message
- case <-timeout.C:
- // Safety timeout
- return
- }
- }
- }()
- })
- }
-
- // Ensure cleanup runs when handler exits
- defer cleanup()
-
- // Monitor for client disconnection
+ // Monitor for client disconnection in a separate goroutine
go func() {
- select {
- case <-ctx.Done(): // Context done = client disconnected or request canceled
- cleanup()
+ <-ctx.Done() // Wait for context cancellation (client disconnected)
+
+ cleanupMutex.Lock()
+ if !cleanupDone {
+ cleanupDone = true
+ cleanupMutex.Unlock()
+
+ log.Printf("Client disconnected for key %s, cleaning up", dbKey)
+
+ // Signal the event server to remove this client
+ select {
+ case utils.ValueEventServer.ClosedClients <- utils.KeyClientPair{Key: dbKey, Client: clientChan}:
+ // Successfully sent cleanup signal
+ case <-time.After(500 * time.Millisecond):
+ // Timed out waiting to send cleanup signal
+ log.Printf("Warning: Timed out sending cleanup signal for %s after disconnect", dbKey)
+ }
+ } else {
+ cleanupMutex.Unlock()
}
}()
// Send initial value
initialVal := Client.Get(context.Background(), dbKey).Val()
if count, err := strconv.Atoi(initialVal); err == nil {
- c.SSEvent("message", map[string]int{"value": count})
+ // Write the event manually to preserve the existing data: {"value":N} format
+ _, err := c.Writer.WriteString(fmt.Sprintf("data: {\"value\":%d}\n\n", count))
+ if err != nil {
+ log.Printf("Error writing to client: %v", err)
+ return
+ }
c.Writer.Flush()
}
- // Stream updates with clear error handling
+ // Stream updates
c.Stream(func(w io.Writer) bool {
select {
case <-ctx.Done():
- log.Printf("Client context done for key %s", dbKey)
return false
-
case count, ok := <-clientChan:
if !ok {
- log.Printf("Client channel closed for key %s", dbKey)
return false
}
-
- // Use SSEvent for consistent formatting
- c.SSEvent("message", map[string]int{"value": count})
-
+ // Write the event manually to preserve the existing data: {"value":N} format
+ _, err := c.Writer.WriteString(fmt.Sprintf("data: {\"value\":%d}\n\n", count))
+ if err != nil {
+ log.Printf("Error writing to client: %v", err)
+ return false
+ }
+ c.Writer.Flush()
return true
}
})
diff --git a/routes_test.go b/routes_test.go
index 33e07b0..c7c7a30 100644
--- a/routes_test.go
+++ b/routes_test.go
@@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"os"
+ "strings"
"testing"
"time"
@@ -482,6 +483,10 @@ func TestStreamValueView(t *testing.T) {
w := newMockResponseWriter()
req, _ := http.NewRequest("GET", "/stream/test/stream_key", nil)
+ // Create a new context with cancellation
+ requestCtx, cancelFunc := context.WithCancel(req.Context())
+ req = req.WithContext(requestCtx)
+
// Channel to signal test completion
done := make(chan struct{})
go func() {
@@ -489,10 +494,21 @@ func TestStreamValueView(t *testing.T) {
r.ServeHTTP(w, req)
}()
- // Give the stream some time to start
- time.Sleep(100 * time.Millisecond)
+ // Wait for initial response
+ waitForContains := func(w *mockResponseWriter, expected string, timeout time.Duration) bool {
+ deadline := time.Now().Add(timeout)
+ for time.Now().Before(deadline) {
+ if strings.Contains(w.Body.String(), expected) {
+ return true
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ return false
+ }
- assert.Contains(t, w.Header().Get("Content-Type"), "text/event-stream")
+ if !waitForContains(w, "data:", 500*time.Millisecond) {
+ t.Fatal("Initial SSE connection not established in time")
+ }
// Hit the key to generate updates
hitReq, _ := http.NewRequest("GET", "/hit/test/stream_key", nil)
@@ -500,17 +516,28 @@ func TestStreamValueView(t *testing.T) {
// Trigger updates
hitW := httptest.NewRecorder()
r.ServeHTTP(hitW, hitReq)
- time.Sleep(50 * time.Millisecond) // Allow the stream to process
- assert.Contains(t, w.Body.String(), "data: {\"value\":1}\n\n")
- r.ServeHTTP(hitW, hitReq) // Hit it again
- time.Sleep(50 * time.Millisecond) // Allow the stream to process
- assert.Contains(t, w.Body.String(), "data: {\"value\":2}\n\n")
+ // Check for value 1 with timeout
+ if !waitForContains(w, "data: {\"value\":1}\n\n", 500*time.Millisecond) {
+ t.Fatal("Did not receive first update in time")
+ }
+
+ r.ServeHTTP(hitW, hitReq) // Hit it again
+
+ // Check for value 2 with timeout
+ if !waitForContains(w, "data: {\"value\":2}\n\n", 500*time.Millisecond) {
+ t.Fatal("Did not receive second update in time")
+ }
// Signal the stream to stop
+ cancelFunc()
+
+ // Wait for goroutine to finish with timeout
select {
case <-done:
- case <-time.After(1 * time.Second): // Ensure test doesn't hang forever
+ // Test completed successfully
+ case <-time.After(1 * time.Second):
+ t.Fatal("Test timed out waiting for stream to close")
}
})
}
diff --git a/stream_test.go b/stream_test.go
new file mode 100644
index 0000000..2cd89d8
--- /dev/null
+++ b/stream_test.go
@@ -0,0 +1,379 @@
+package main
+
+import (
+ "bufio"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/jasonlovesdoggo/abacus/utils"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// MockResponseRecorder wraps httptest.ResponseRecorder and adds CloseNotify support for testing
+type MockResponseRecorder struct {
+ *httptest.ResponseRecorder
+ closeNotify chan bool
+}
+
+// NewMockResponseRecorder creates a new response recorder with CloseNotify support
+func NewMockResponseRecorder() *MockResponseRecorder {
+ return &MockResponseRecorder{
+ ResponseRecorder: httptest.NewRecorder(),
+ closeNotify: make(chan bool, 1),
+ }
+}
+
+// CloseNotify implements http.CloseNotifier
+func (m *MockResponseRecorder) CloseNotify() <-chan bool {
+ return m.closeNotify
+}
+
+// Close simulates a client disconnection
+func (m *MockResponseRecorder) Close() {
+ select {
+ case m.closeNotify <- true:
+ // Signal sent
+ default:
+ // Channel already has a value or is closed
+ }
+}
+
+// TestStreamBasicFunctionality tests that the stream endpoint correctly
+// sends events when values are updated
+func TestStreamBasicFunctionality(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ router := setupTestRouter()
+
+ // Create a counter first
+ createResp := httptest.NewRecorder()
+ createReq, _ := http.NewRequest("POST", "/create/test/stream-test", nil)
+ router.ServeHTTP(createResp, createReq)
+ assert.Equal(t, http.StatusCreated, createResp.Code)
+
+ // For streaming tests, we need a real HTTP server
+ server := httptest.NewServer(router)
+ defer server.Close()
+
+ // Use a real HTTP client to connect to the server
+ client := &http.Client{
+ Timeout: 5 * time.Second,
+ }
+
+ req, err := http.NewRequest("GET", server.URL+"/stream/test/stream-test", nil)
+ require.NoError(t, err)
+
+ req.Header.Set("Accept", "text/event-stream")
+ resp, err := client.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // Channel to collect received events
+ events := make(chan string, 10)
+ done := make(chan struct{})
+
+ // Process the SSE stream
+ go func() {
+ defer close(done)
+ scanner := bufio.NewScanner(resp.Body)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if strings.HasPrefix(line, "data: ") {
+ select {
+ case events <- line:
+ // Event sent
+ case <-time.After(100 * time.Millisecond):
+ // Buffer full, drop event
+ t.Logf("Event buffer full, dropped: %s", line)
+ }
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ t.Logf("Scanner error: %v", err)
+ }
+ }()
+
+ // Wait for initial value
+ select {
+ case event := <-events:
+ assert.True(t, strings.HasPrefix(event, "data: {\"value\":"))
+ case <-time.After(1 * time.Second):
+ t.Fatal("Timeout waiting for initial event")
+ }
+
+ // Hit the counter to increment its value
+ hitResp, err := client.Get(server.URL + "/hit/test/stream-test")
+ require.NoError(t, err)
+ hitResp.Body.Close()
+ assert.Equal(t, http.StatusOK, hitResp.StatusCode)
+
+ // Check that we got an update event
+ select {
+ case event := <-events:
+ assert.True(t, strings.HasPrefix(event, "data: {\"value\":"))
+
+ // Extract the value
+ value := extractValueFromEvent(event)
+ assert.Equal(t, 1, value)
+ case <-time.After(1 * time.Second):
+ t.Fatal("Timeout waiting for update event")
+ }
+
+ // Close connection
+ resp.Body.Close()
+
+ // Give some time for cleanup
+ time.Sleep(500 * time.Millisecond)
+
+ // Verify proper cleanup
+ clientCount := countClientsForKey("K:test:stream-test")
+ assert.Equal(t, 0, clientCount, "Client wasn't properly cleaned up after disconnection")
+}
+
+// TestMultipleClients tests multiple clients connecting to the same stream
+func TestMultipleClients(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ router := setupTestRouter()
+
+ // Create a counter
+ createResp := httptest.NewRecorder()
+ createReq, _ := http.NewRequest("POST", "/create/test/multi-client", nil)
+ router.ServeHTTP(createResp, createReq)
+ assert.Equal(t, http.StatusCreated, createResp.Code)
+
+ // Start a real HTTP server
+ server := httptest.NewServer(router)
+ defer server.Close()
+
+ // Number of clients to test
+ numClients := 3 // Reduced from 5 for faster testing
+
+ // Set up client trackers
+ type clientState struct {
+ resp *http.Response
+ events chan string
+ done chan struct{}
+ lastValue int
+ eventCount int
+ }
+
+ clients := make([]*clientState, numClients)
+
+ // Start all clients
+ for i := 0; i < numClients; i++ {
+ // Create client state
+ clients[i] = &clientState{
+ events: make(chan string, 10),
+ done: make(chan struct{}),
+ }
+
+ // Create request
+ req, err := http.NewRequest("GET", server.URL+"/stream/test/multi-client", nil)
+ require.NoError(t, err)
+ req.Header.Set("Accept", "text/event-stream")
+
+ // Connect client
+ client := &http.Client{
+ Timeout: 5 * time.Second,
+ }
+ resp, err := client.Do(req)
+ require.NoError(t, err)
+ clients[i].resp = resp
+
+ // Process events
+ go func(idx int) {
+ defer close(clients[idx].done)
+ scanner := bufio.NewScanner(clients[idx].resp.Body)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if strings.HasPrefix(line, "data: ") {
+ select {
+ case clients[idx].events <- line:
+ // Event sent
+ default:
+ // Buffer full, drop event
+ }
+ }
+ }
+ }(i)
+ }
+
+ // Give time for all clients to connect
+ time.Sleep(300 * time.Millisecond)
+
+ // Verify all clients receive initial value
+ for i := 0; i < numClients; i++ {
+ select {
+ case event := <-clients[i].events:
+ clients[i].lastValue = extractValueFromEvent(event)
+ clients[i].eventCount++
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timeout waiting for client %d initial event", i)
+ }
+ }
+
+ // Hit the counter several times
+ for hits := 0; hits < 3; hits++ {
+ client := &http.Client{}
+ resp, err := client.Get(server.URL + "/hit/test/multi-client")
+ require.NoError(t, err)
+ resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ // Give time for events to propagate
+ time.Sleep(100 * time.Millisecond)
+ }
+
+ // Verify all clients received the updates
+ for i := 0; i < numClients; i++ {
+ // Drain all events
+ timeout := time.After(500 * time.Millisecond)
+ draining := true
+
+ for draining {
+ select {
+ case event := <-clients[i].events:
+ clients[i].lastValue = extractValueFromEvent(event)
+ clients[i].eventCount++
+ case <-timeout:
+ draining = false
+ }
+ }
+
+ // Each client should have received at least 4 events (initial + 3 hits)
+ assert.GreaterOrEqual(t, clients[i].eventCount, 4, "Client %d didn't receive enough events", i)
+ assert.Equal(t, 3, clients[i].lastValue, "Client %d has incorrect final value", i)
+ }
+
+ // Disconnect clients one by one and verify cleanup
+ for i := 0; i < numClients; i++ {
+ // Close client connection
+ clients[i].resp.Body.Close()
+
+ // Give time for cleanup
+ time.Sleep(200 * time.Millisecond)
+
+ // Verify decreasing client count
+ clientCount := countClientsForKey("K:test:multi-client")
+ assert.Equal(t, numClients-(i+1), clientCount, "Client wasn't properly cleaned up after disconnection")
+ }
+}
+
+// TestStreamConcurrencyStress tests the stream under high concurrency conditions
+func TestStreamConcurrencyStress(t *testing.T) {
+ // Skip in normal testing as this is a long stress test
+ if testing.Short() {
+ t.Skip("Skipping stress test in short mode")
+ }
+
+ gin.SetMode(gin.ReleaseMode) // Reduce logging noise
+ router := setupTestRouter()
+
+ // Create a counter for stress testing
+ createResp := httptest.NewRecorder()
+ createReq, _ := http.NewRequest("POST", "/create/test/stress-test", nil)
+ router.ServeHTTP(createResp, createReq)
+ require.Equal(t, http.StatusCreated, createResp.Code)
+
+ // Start a real HTTP server
+ server := httptest.NewServer(router)
+ defer server.Close()
+
+ // Test parameters
+ numClients := 20 // Reduced from 50 for faster testing
+ clientDuration := 300 * time.Millisecond
+
+ // Start with no clients
+ initialCount := countClientsForKey("K:test:stress-test")
+ assert.Equal(t, 0, initialCount)
+
+ // Launch many concurrent clients
+ var wg sync.WaitGroup
+ for i := 0; i < numClients; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+
+ // Create client
+ client := &http.Client{}
+
+ // Create request
+ req, err := http.NewRequest("GET", server.URL+"/stream/test/stress-test", nil)
+ if err != nil {
+ t.Logf("Error creating request: %v", err)
+ return
+ }
+ req.Header.Set("Accept", "text/event-stream")
+
+ // Send request
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Logf("Error connecting: %v", err)
+ return
+ }
+
+ // Keep connection open for the duration
+ time.Sleep(clientDuration)
+
+ // Close connection
+ resp.Body.Close()
+ }(i)
+
+ // Stagger client creation slightly
+ time.Sleep(5 * time.Millisecond)
+ }
+
+ // Wait for all clients to finish
+ wg.Wait()
+
+ // Give extra time for any cleanup
+ time.Sleep(1 * time.Second)
+
+ // Verify all clients were cleaned up
+ finalCount := countClientsForKey("K:test:stress-test")
+ assert.Equal(t, 0, finalCount, "Not all clients were cleaned up after stress test")
+
+ // Check we can still connect new clients
+ client := &http.Client{}
+ req, err := http.NewRequest("GET", server.URL+"/stream/test/stress-test", nil)
+ require.NoError(t, err)
+ req.Header.Set("Accept", "text/event-stream")
+
+ resp, err := client.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // Give time for connection
+ time.Sleep(200 * time.Millisecond)
+
+ // Verify new client connected
+ newCount := countClientsForKey("K:test:stress-test")
+ assert.Equal(t, 1, newCount, "Failed to connect new client after stress test")
+
+ // Clean up
+ resp.Body.Close()
+ time.Sleep(200 * time.Millisecond)
+}
+
+func countClientsForKey(key string) int {
+ return utils.ValueEventServer.CountClientsForKey(key)
+}
+
+func extractValueFromEvent(event string) int {
+ // Format is "data: {"value":X}"
+ jsonStr := strings.TrimPrefix(event, "data: ")
+ var data struct {
+ Value int `json:"value"`
+ }
+ err := json.Unmarshal([]byte(jsonStr), &data)
+ if err != nil {
+ return -1
+ }
+ return data.Value
+}
diff --git a/utils/sse.go b/utils/sse.go
index fb9b72f..b09d155 100644
--- a/utils/sse.go
+++ b/utils/sse.go
@@ -3,6 +3,7 @@ package utils
import (
"log"
"sync"
+ "time"
)
type ValueEvent struct {
@@ -25,9 +26,10 @@ type KeyClientPair struct {
func NewValueEventServer() *ValueEvent {
event := &ValueEvent{
- Message: make(chan KeyValue),
- NewClients: make(chan KeyClientPair),
- ClosedClients: make(chan KeyClientPair),
+ // Use buffered channels to prevent blocking
+ Message: make(chan KeyValue, 100),
+ NewClients: make(chan KeyClientPair, 100),
+ ClosedClients: make(chan KeyClientPair, 100),
TotalClients: make(map[string]map[chan int]bool),
}
go event.listen()
@@ -48,26 +50,85 @@ func (v *ValueEvent) listen() {
case closedClient := <-v.ClosedClients:
v.Mu.Lock()
- delete(v.TotalClients[closedClient.Key], closedClient.Client)
- close(closedClient.Client)
+ if clients, exists := v.TotalClients[closedClient.Key]; exists {
+ if _, ok := clients[closedClient.Client]; ok {
+ delete(clients, closedClient.Client)
- // Clean up key map if no more clients
- if len(v.TotalClients[closedClient.Key]) == 0 {
- delete(v.TotalClients, closedClient.Key)
+ // Close channel safely
+ close(closedClient.Client)
+
+ log.Printf("Removed client for key %s", closedClient.Key)
+
+ // Clean up key map if no more clients
+ if len(clients) == 0 {
+ delete(v.TotalClients, closedClient.Key)
+ log.Printf("No more clients for key %s, removed key entry", closedClient.Key)
+ }
+ }
}
v.Mu.Unlock()
- log.Printf("Removed client for key %s", closedClient.Key)
case keyValue := <-v.Message:
+ // First, get a snapshot of clients under read lock
v.Mu.RLock()
- for clientChan := range v.TotalClients[keyValue.Key] {
- clientChan <- keyValue.Value
+ clients, exists := v.TotalClients[keyValue.Key]
+ if !exists || len(clients) == 0 {
+ v.Mu.RUnlock()
+ continue
+ }
+
+ // Create a safe copy of client channels
+ clientChannels := make([]chan int, 0, len(clients))
+ for clientChan := range clients {
+ clientChannels = append(clientChannels, clientChan)
}
v.Mu.RUnlock()
+
+ // Send messages without holding the lock
+ // Track which clients failed to receive
+ var failedClients []chan int
+ for _, clientChan := range clientChannels {
+ select {
+ case clientChan <- keyValue.Value:
+ // Message sent successfully
+ case <-time.After(100 * time.Millisecond):
+ // Client not responding, mark for removal
+ failedClients = append(failedClients, clientChan)
+ }
+ }
+
+ // Schedule removal of failed clients
+ for _, failedClient := range failedClients {
+ select {
+ case v.ClosedClients <- KeyClientPair{Key: keyValue.Key, Client: failedClient}:
+ // Client scheduled for removal
+ default:
+ // If ClosedClients channel is full, try again later
+ go func(key string, client chan int) {
+ time.Sleep(200 * time.Millisecond)
+ select {
+ case v.ClosedClients <- KeyClientPair{Key: key, Client: client}:
+ // Success on retry
+ default:
+ log.Printf("Failed to remove client for key %s even after retry", key)
+ }
+ }(keyValue.Key, failedClient)
+ }
+ }
}
}
}
+func (v *ValueEvent) CountClientsForKey(key string) int {
+ v.Mu.RLock()
+ defer v.Mu.RUnlock()
+
+ if clients, exists := v.TotalClients[key]; exists {
+ return len(clients)
+ }
+ return 0
+}
+
// Global event server
var ValueEventServer *ValueEvent
@@ -77,10 +138,12 @@ func init() {
// When you want to update a value and notify clients for a specific key
func SetStream(dbKey string, newValue int) {
- // Broadcast the new value only to clients listening to this specific key
- ValueEventServer.Message <- KeyValue{
- Key: dbKey,
- Value: newValue,
+ // Use a non-blocking send with default case to prevent blocking
+ select {
+ case ValueEventServer.Message <- KeyValue{Key: dbKey, Value: newValue}:
+ // Message sent successfully
+ default:
+ log.Printf("Warning: Message channel full, update for key %s dropped", dbKey)
}
}
@@ -100,8 +163,11 @@ func CloseStream(dbKey string) {
ValueEventServer.Mu.Unlock()
// Now close the channels after releasing the lock
- // This ensures we're not holding the lock while performing potentially blocking operations
for _, ch := range channelsToClose {
close(ch)
}
+
+ if len(channelsToClose) > 0 {
+ log.Printf("Closed all streams for key %s (%d clients)", dbKey, len(channelsToClose))
+ }
}
diff --git a/utils/stats.go b/utils/stats.go
index 2e5e757..58886bf 100644
--- a/utils/stats.go
+++ b/utils/stats.go
@@ -150,6 +150,9 @@ func (sm *StatManager) monitorHealth() {
// Log that we're skipping save despite high total because buffer is empty
log.Printf("Total count high (%d/%d) but buffer is empty. Skipping unnecessary save operation.",
snapshot.Total, totalWarningThreshold)
+ } else {
+ log.Printf("Buffer (%d/%d) and total count (%d/%d) are within acceptable limits. Skipping save operation.",
+ len(sm.buffer), batchSize, snapshot.Total, totalWarningThreshold)
}
}
}