diff --git a/.gitignore b/.gitignore index d676b90..88509c8 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,7 @@ tmp .env !.idea/httpRequests +.venv +abacus.exe +*.bru +bruno.json diff --git a/.idea/vcs.xml b/.idea/vcs.xml index aae8b7c..4c6280e 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -9,4 +9,4 @@ - + \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 8ee2a34..ad5a34e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,5 +18,5 @@ CMD ["/abacus"] HEALTHCHECK --interval=10s --timeout=3s --start-period=5s --retries=3 CMD wget -S -O - http://0.0.0.0:8080/healthcheck || exit 1 LABEL maintainer="Jason Cameron abacus@jasoncameron.dev" -LABEL version="1.3.3" +LABEL version="1.4.0" LABEL description="This is a simple countAPI service written in Go." diff --git a/docs/bugs/GHSA-vh64-54px-qgf8/results.md b/docs/bugs/GHSA-vh64-54px-qgf8/results.md new file mode 100644 index 0000000..7a873bd --- /dev/null +++ b/docs/bugs/GHSA-vh64-54px-qgf8/results.md @@ -0,0 +1,101 @@ + +# After malicious script +```bash +❯ curl localhost:8080/stream/stream/eee/ -vvv +23:01:20.486586 [0-x] == Info: [READ] client_reset, clear readers +23:01:20.489178 [0-0] == Info: Host localhost:8080 was resolved. +23:01:20.491295 [0-0] == Info: IPv6: ::1 +23:01:20.492584 [0-0] == Info: IPv4: 127.0.0.1 +23:01:20.494114 [0-0] == Info: [SETUP] added +23:01:20.495946 [0-0] == Info: Trying [::1]:8080... +23:01:20.498051 [0-0] == Info: Connected to localhost (::1) port 8080 +23:01:20.500158 [0-0] == Info: using HTTP/1.x +23:01:20.502118 [0-0] => Send header, 96 bytes (0x60) +0000: GET /stream/stream/eee/ HTTP/1.1 +0022: Host: localhost:8080 +0038: User-Agent: curl/8.10.1 +0051: Accept: */* +005e: +23:01:20.507492 [0-0] == Info: Request completely sent off +23:01:41.832824 [0-0] == Info: Recv failure: Connection was reset +23:01:41.835258 [0-0] == Info: [WRITE] cw-out done +23:01:41.836819 [0-0] == Info: closing connection #0 +23:01:41.838722 [0-0] == Info: [SETUP] close +23:01:41.840378 [0-0] == Info: [SETUP] destroy +curl: (56) Recv failure: Connection was reset +~ via  v3.13.1 took 21s +``` + +# Before malicious script +```bash +❯ curl localhost:8080/stream/stream/eee/ -vvv +23:01:48.679494 [0-x] == Info: [READ] client_reset, clear readers +23:01:48.682091 [0-0] == Info: Host localhost:8080 was resolved. +23:01:48.684116 [0-0] == Info: IPv6: ::1 +23:01:48.685387 [0-0] == Info: IPv4: 127.0.0.1 +23:01:48.687056 [0-0] == Info: [SETUP] added +23:01:48.688790 [0-0] == Info: Trying [::1]:8080... 
+23:01:48.690916 [0-0] == Info: Connected to localhost (::1) port 8080 +23:01:48.692898 [0-0] == Info: using HTTP/1.x +23:01:48.694577 [0-0] => Send header, 96 bytes (0x60) +0000: GET /stream/stream/eee/ HTTP/1.1 +0022: Host: localhost:8080 +0038: User-Agent: curl/8.10.1 +0051: Accept: */* +005e: +23:01:48.699693 [0-0] == Info: Request completely sent off +23:01:48.722188 [0-0] <= Recv header, 17 bytes (0x11) +0000: HTTP/1.1 200 OK +23:01:48.724437 [0-0] == Info: [WRITE] cw_out, wrote 17 header bytes -> 17 +23:01:48.726908 [0-0] == Info: [WRITE] download_write header(type=c, blen=17) -> 0 +23:01:48.729546 [0-0] == Info: [WRITE] client_write(type=c, len=17) -> 0 +23:01:48.731849 [0-0] <= Recv header, 25 bytes (0x19) +0000: Cache-Control: no-cache +23:01:48.734304 [0-0] == Info: [WRITE] header_collect pushed(type=1, len=25) -> 0 +23:01:48.736960 [0-0] == Info: [WRITE] cw_out, wrote 25 header bytes -> 25 +23:01:48.739246 [0-0] == Info: [WRITE] download_write header(type=4, blen=25) -> 0 +23:01:48.741791 [0-0] == Info: [WRITE] client_write(type=4, len=25) -> 0 +23:01:48.744142 [0-0] <= Recv header, 24 bytes (0x18) +0000: Connection: keep-alive +23:01:48.746735 [0-0] == Info: [WRITE] header_collect pushed(type=1, len=24) -> 0 +23:01:48.749113 [0-0] == Info: [WRITE] cw_out, wrote 24 header bytes -> 24 +23:01:48.751155 [0-0] == Info: [WRITE] download_write header(type=4, blen=24) -> 0 +23:01:48.753506 [0-0] == Info: [WRITE] client_write(type=4, len=24) -> 0 +23:01:48.755627 [0-0] <= Recv header, 33 bytes (0x21) +0000: Content-Type: text/event-stream +23:01:48.758653 [0-0] == Info: [WRITE] header_collect pushed(type=1, len=33) -> 0 +23:01:48.761138 [0-0] == Info: [WRITE] cw_out, wrote 33 header bytes -> 33 +23:01:48.763515 [0-0] == Info: [WRITE] download_write header(type=4, blen=33) -> 0 +23:01:48.766179 [0-0] == Info: [WRITE] client_write(type=4, len=33) -> 0 +23:01:48.768501 [0-0] <= Recv header, 37 bytes (0x25) +0000: Date: Sun, 02 Mar 2025 04:01:48 GMT +23:01:48.771446 [0-0] == Info: [WRITE] header_collect pushed(type=1, len=37) -> 0 +23:01:48.773906 [0-0] == Info: [WRITE] cw_out, wrote 37 header bytes -> 37 +23:01:48.776037 [0-0] == Info: [WRITE] download_write header(type=4, blen=37) -> 0 +23:01:48.778591 [0-0] == Info: [WRITE] client_write(type=4, len=37) -> 0 +23:01:48.780857 [0-0] == Info: [WRITE] looking for transfer decoder: chunked +23:01:48.783072 [0-0] == Info: [WRITE] added transfer decoder chunked -> 0 +23:01:48.785259 [0-0] <= Recv header, 28 bytes (0x1c) +0000: Transfer-Encoding: chunked +23:01:48.788082 [0-0] == Info: [WRITE] header_collect pushed(type=1, len=28) -> 0 +23:01:48.790781 [0-0] == Info: [WRITE] cw_out, wrote 28 header bytes -> 28 +23:01:48.792926 [0-0] == Info: [WRITE] download_write header(type=4, blen=28) -> 0 +23:01:48.795653 [0-0] == Info: [WRITE] client_write(type=4, len=28) -> 0 +23:01:48.797787 [0-0] <= Recv header, 2 bytes (0x2) +0000: +23:01:48.799552 [0-0] == Info: [WRITE] header_collect pushed(type=1, len=2) -> 0 +23:01:48.802100 [0-0] == Info: [WRITE] cw_out, wrote 2 header bytes -> 2 +23:01:48.804303 [0-0] == Info: [WRITE] download_write header(type=4, blen=2) -> 0 +23:01:48.806762 [0-0] == Info: [WRITE] client_write(type=4, len=2) -> 0 +23:01:48.808821 [0-0] <= Recv data, 26 bytes (0x1a) +0000: 14 +0004: data: {"value":31}.. 
+23:01:48.811317 [0-0] == Info: [WRITE] http_chunked, chunk start of 20 bytes +data: {"value":31} + +23:01:48.813647 [0-0] == Info: [WRITE] cw_out, wrote 20 body bytes -> 20 +23:01:48.815844 [0-0] == Info: [WRITE] download_write body(type=1, blen=20) -> 0 +23:01:48.818394 [0-0] == Info: [WRITE] http_chunked, write 20 body bytes, 0 bytes in chunk remain +23:01:48.821546 [0-0] == Info: [WRITE] client_write(type=1, len=26) -> 0 +23:01:48.823790 [0-0] == Info: [WRITE] xfer_write_resp(len=192, eos=0) -> 0 +``` diff --git a/docs/bugs/GHSA-vh64-54px-qgf8/test.py b/docs/bugs/GHSA-vh64-54px-qgf8/test.py new file mode 100644 index 0000000..ceac67d --- /dev/null +++ b/docs/bugs/GHSA-vh64-54px-qgf8/test.py @@ -0,0 +1,498 @@ +#!/usr/bin/env python3 +""" +Enhanced test script for Abacus SSE goroutine leak. + +This script creates three types of connections: +1. Quick connections that immediately disconnect +2. Lingering connections that stay open for a longer period +3. Zombie connections that are left open until the end of the test + +This mixed approach better simulates real-world client behavior and tests +the server's ability to clean up all types of disconnected clients. + +Author: JasonLovesDoggo (2025-03-02) +""" + +import argparse +import concurrent.futures +import json +import os +import platform +import signal +import sys +import threading +import time +from datetime import datetime +try: + import psutil + PSUTIL_AVAILABLE = True +except ImportError: + PSUTIL_AVAILABLE = False + +import requests +import colorama + +# Initialize colorama for Windows terminal support +colorama.init() + +class AbacusLeakTester: + def __init__(self, server_url, process_name='abacus', + quick_connections=50, lingering_connections=20, zombie_connections=10, + num_batches=5, delay_between_batches=5, linger_time=10, + timeout=3, max_workers=10): + self.server_url = server_url.rstrip('/') + self.endpoint = "/stream/test/leak_test" + self.process_name = process_name + + # Connection types and counts + self.quick_connections = quick_connections + self.lingering_connections = lingering_connections + self.zombie_connections = zombie_connections + + self.num_batches = num_batches + self.delay_between_batches = delay_between_batches + self.linger_time = linger_time + self.timeout = timeout + self.max_workers = max_workers + self.process = None + self.initial_memory = None + + # Tracking the zombie connections + self.zombie_sessions = [] + self.zombie_responses = [] + + # Results tracking + self.total_successful = 0 + self.total_failed = 0 + self.memory_readings = [] + + # Windows-specific process name adjustments + if platform.system() == "Windows": + if not self.process_name.lower().endswith('.exe'): + self.process_name += '.exe' + + def find_process(self): + """Find the server process if running locally.""" + if not PSUTIL_AVAILABLE: + print("psutil module not available. 
Memory tracking disabled.") + return None + + for proc in psutil.process_iter(['pid', 'name']): + try: + if self.process_name.lower() in proc.info['name'].lower(): + return proc + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + pass + + return None + + def get_memory_usage(self): + """Get current memory usage of the process in MB.""" + if not self.process: + return None + + try: + # Wait a bit for any memory operations to settle + time.sleep(1) + self.process.memory_info() # Refresh process info + memory = self.process.memory_info().rss / (1024 * 1024) + self.memory_readings.append(memory) + return memory + except (psutil.NoSuchProcess, psutil.AccessDenied): + print("Process no longer accessible or has terminated.") + self.process = None + return None + + def create_test_counter(self): + """Create a test counter if it doesn't exist.""" + try: + response = requests.post(f"{self.server_url}/create/test/leak_test") + if response.status_code == 201: + print("\033[92mCreated test counter\033[0m") + else: + print(f"\033[93mCounter creation response: {response.status_code}\033[0m") + except Exception as e: + print(f"\033[93mCounter may already exist, continuing... ({e})\033[0m") + + def make_quick_connection(self, connection_id, batch_id): + """Make a connection that immediately disconnects.""" + session = requests.Session() + try: + # Start a streaming request + headers = {"Accept": "text/event-stream"} + response = session.get( + f"{self.server_url}{self.endpoint}", + headers=headers, + stream=True, + timeout=self.timeout + ) + + # Just read a tiny bit of data + try: + next(response.iter_content(chunk_size=64)) + except (StopIteration, requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout): + pass + + # Abruptly close the connection + response.close() + + # Increment the counter + hit_response = session.get(f"{self.server_url}/hit/test/leak_test") + hit_response.raise_for_status() + + return True + except Exception as e: + if "Read timed out" not in str(e): # Ignore expected timeouts + print(f"Quick connection {connection_id} in batch {batch_id} failed: {str(e)}") + return False + finally: + session.close() + + def make_lingering_connection(self, connection_id, batch_id): + """Make a connection that stays open for a while before closing.""" + session = requests.Session() + try: + # Start a streaming request + headers = {"Accept": "text/event-stream"} + response = session.get( + f"{self.server_url}{self.endpoint}", + headers=headers, + stream=True, + timeout=self.timeout + self.linger_time + ) + + # Read a bit of data + try: + for _ in range(2): # Read a couple of chunks to establish connection + next(response.iter_content(chunk_size=128)) + except (StopIteration, requests.exceptions.ChunkedEncodingError): + pass + + print(f" Lingering connection {connection_id} established, will stay open for {self.linger_time}s") + + # Keep connection open for a while + time.sleep(self.linger_time) + + # Properly close the connection + response.close() + + # Increment the counter + hit_response = session.get(f"{self.server_url}/hit/test/leak_test") + hit_response.raise_for_status() + + print(f" Lingering connection {connection_id} properly closed after {self.linger_time}s") + return True + except Exception as e: + print(f"Lingering connection {connection_id} in batch {batch_id} failed: {str(e)}") + return False + finally: + session.close() + + def make_zombie_connection(self, connection_id, batch_id): + """Make a connection 
that is never explicitly closed (until cleanup).""" + try: + # Create a persistent session + session = requests.Session() + self.zombie_sessions.append(session) + + # Start a streaming request + headers = {"Accept": "text/event-stream"} + response = session.get( + f"{self.server_url}{self.endpoint}", + headers=headers, + stream=True, + timeout=60 # Long timeout + ) + self.zombie_responses.append(response) + + # Read just a bit to establish the connection + try: + next(response.iter_content(chunk_size=64)) + except (StopIteration, requests.exceptions.ChunkedEncodingError): + return False + + print(f" Zombie connection {connection_id} established (will remain open)") + + # Increment the counter + hit_response = requests.get(f"{self.server_url}/hit/test/leak_test") + hit_response.raise_for_status() + + return True + except Exception as e: + print(f"Zombie connection {connection_id} in batch {batch_id} failed: {str(e)}") + return False + + def cleanup_zombie_connections(self): + """Clean up any zombie connections at the end of the test.""" + print("\n\033[93mCleaning up zombie connections...\033[0m") + for i, response in enumerate(self.zombie_responses): + try: + response.close() + print(f" Closed zombie connection {i+1}") + except: + pass + + for i, session in enumerate(self.zombie_sessions): + try: + session.close() + print(f" Closed zombie session {i+1}") + except: + pass + + # Clear the lists + self.zombie_responses = [] + self.zombie_sessions = [] + + def run_batch(self, batch_id): + """Run a batch with different connection types.""" + print(f"\033[95mStarting batch {batch_id} of {self.num_batches}\033[0m") + + batch_start = time.time() + batch_successful = 0 + batch_failed = 0 + + # 1. Quick connections (in parallel) + if self.quick_connections > 0: + print(f" Creating {self.quick_connections} quick connections...") + with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = [] + for i in range(1, self.quick_connections + 1): + futures.append(executor.submit(self.make_quick_connection, i, batch_id)) + if i % 20 == 0: + time.sleep(0.2) # Stagger connections + + for future in concurrent.futures.as_completed(futures): + if future.result(): + batch_successful += 1 + else: + batch_failed += 1 + + # 2. Lingering connections (in parallel but with careful thread management) + if self.lingering_connections > 0: + print(f" Creating {self.lingering_connections} lingering connections...") + # Use a smaller thread pool to avoid overwhelming resources + max_lingering_threads = min(5, self.max_workers) + with concurrent.futures.ThreadPoolExecutor(max_workers=max_lingering_threads) as executor: + futures = [] + for i in range(1, self.lingering_connections + 1): + futures.append(executor.submit(self.make_lingering_connection, i, batch_id)) + time.sleep(0.5) # Stagger lingering connections more + + for future in concurrent.futures.as_completed(futures): + if future.result(): + batch_successful += 1 + else: + batch_failed += 1 + + # 3. 
Zombie connections (create but don't close) + if batch_id == 1 and self.zombie_connections > 0: # Only create zombies in first batch + print(f" Creating {self.zombie_connections} zombie connections...") + for i in range(1, self.zombie_connections + 1): + if self.make_zombie_connection(i, batch_id): + batch_successful += 1 + else: + batch_failed += 1 + time.sleep(0.5) # Delay between zombie connections + + self.total_successful += batch_successful + self.total_failed += batch_failed + + batch_duration = time.time() - batch_start + print(f"\033[92mBatch {batch_id} completed: {batch_successful} successful, " + f"{batch_failed} failed. Duration: {batch_duration:.2f}s\033[0m") + + return batch_successful, batch_failed + + def check_memory(self): + """Check and report on memory usage.""" + if not self.process: + return True + + current_memory = self.get_memory_usage() + if current_memory is None: + print("\033[91mProcess no longer found. Server may have crashed!\033[0m") + return False + + memory_diff = current_memory - self.initial_memory + color = "\033[96m" # Cyan + if memory_diff > 15: + color = "\033[91m" # Red + elif memory_diff > 5: + color = "\033[93m" # Yellow + + sign = "+" if memory_diff >= 0 else "" + print(f"{color}Current memory: {current_memory:.2f} MB ({sign}{memory_diff:.2f} MB)\033[0m") + return True + + def print_final_stats(self): + """Print final test statistics.""" + final_memory = None + if self.process: + final_memory = self.get_memory_usage() + + print("\n\033[95mTest completed!\033[0m") + total_connections = ( + (self.quick_connections * self.num_batches) + + (self.lingering_connections * self.num_batches) + + self.zombie_connections + ) + print(f"\033[97mTotal connections attempted: {total_connections}\033[0m") + print(f" Quick connections: {self.quick_connections * self.num_batches}") + print(f" Lingering connections: {self.lingering_connections * self.num_batches}") + print(f" Zombie connections: {self.zombie_connections}") + + print(f"\033[92mSuccessful: {self.total_successful}\033[0m") + color = "\033[92m" if self.total_failed == 0 else "\033[91m" + print(f"{color}Failed: {self.total_failed}\033[0m") + + if self.initial_memory is not None and len(self.memory_readings) > 1: + print(f"\n\033[96mMemory Analysis:\033[0m") + print(f"\033[97mInitial memory: {self.initial_memory:.2f} MB\033[0m") + print(f"\033[97mFinal memory: {final_memory:.2f} MB\033[0m") + + memory_growth = final_memory - self.initial_memory + growth_percent = (memory_growth / self.initial_memory) * 100 + + sign = "+" if memory_growth >= 0 else "" + color = "\033[92m" # Green + if growth_percent > 20: + color = "\033[91m" # Red + elif growth_percent > 10: + color = "\033[93m" # Yellow + + print(f"{color}Growth: {sign}{memory_growth:.2f} MB ({sign}{growth_percent:.2f}%)\033[0m") + + # Check for consistent growth pattern + if len(self.memory_readings) >= 3: + print("\n\033[96mMemory Growth Pattern:\033[0m") + for i in range(1, len(self.memory_readings)): + diff = self.memory_readings[i] - self.memory_readings[i-1] + print(f" Batch {i}: {self.memory_readings[i]:.2f} MB ({'+' if diff >= 0 else ''}{diff:.2f} MB)") + + # Check for leak indicators + consistent_growth = True + baseline_diff = self.memory_readings[1] - self.memory_readings[0] + for i in range(2, len(self.memory_readings)): + diff = self.memory_readings[i] - self.memory_readings[i-1] + # If growth is inconsistent (allowing for some variance) + if diff < 0 or abs(diff - baseline_diff) > max(baseline_diff * 0.5, 1.0): + consistent_growth = False + 
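+                # Heuristic only: >10 MB of growth with a consistent batch-over-batch pattern is flagged as a likely leak;
+                # >10 MB of uneven growth is reported as a warning, and smaller changes are treated as normal variation.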
+ if memory_growth > 10 and consistent_growth: + print(f"\033[91mConsistent memory growth detected across batches!") + print(f"This strongly indicates a memory/goroutine leak.\033[0m") + elif memory_growth > 10: + print(f"\033[93mSignificant memory growth detected but pattern is inconsistent.") + print(f"This may indicate a partial leak or normal memory variation.\033[0m") + elif memory_growth > 5: + print(f"\033[93mModerate memory growth detected. May be normal variation.\033[0m") + else: + print(f"\033[92mMemory usage appears stable. No obvious leak detected.\033[0m") + + def get_final_counter(self): + """Get the final counter value.""" + try: + response = requests.get(f"{self.server_url}/get/test/leak_test") + counter_value = response.json().get('value', 'unknown') + print(f"\n\033[96mCounter value after test: {counter_value}\033[0m") + except Exception as e: + print(f"\n\033[91mCould not get final counter value: {e}\033[0m") + + def run_test(self): + """Run the complete test.""" + print(f"Testing Abacus SSE endpoint at {self.server_url}") + print(f"Running on {platform.system()} {platform.release()}") + print(f"Connection configuration:") + print(f" - Quick connections: {self.quick_connections}/batch") + print(f" - Lingering connections: {self.lingering_connections}/batch (stay open for {self.linger_time}s)") + print(f" - Zombie connections: {self.zombie_connections} (left open until end)") + print(f" - Total batches: {self.num_batches}") + + # Find process for memory tracking + self.process = self.find_process() + if self.process: + self.initial_memory = self.get_memory_usage() + print(f"\033[96mInitial memory usage: {self.initial_memory:.2f} MB\033[0m") + else: + print("\033[93mCould not find local process. Memory tracking disabled.\033[0m") + + # Create the counter + self.create_test_counter() + + try: + # Run test batches + for batch in range(1, self.num_batches + 1): + self.run_batch(batch) + + # Check memory + if self.process and not self.check_memory(): + break + + # Delay between batches + if batch < self.num_batches: + print(f"\033[90mWaiting {self.delay_between_batches} seconds before next batch...\033[0m") + time.sleep(self.delay_between_batches) + + # After all batches, wait a bit longer to see if memory stabilizes + print("\n\033[93mWaiting 10 seconds for memory to stabilize...\033[0m") + time.sleep(10) + + # Final memory check + if self.process: + self.check_memory() + + # Print statistics + self.print_final_stats() + self.get_final_counter() + + finally: + # Always clean up zombie connections + self.cleanup_zombie_connections() + + +def main(): + parser = argparse.ArgumentParser(description='Test for goroutine leaks in Abacus SSE implementation') + parser.add_argument('--url', default='http://localhost:8080', help='Abacus server URL') + parser.add_argument('--process', default='abacus', help='Process name for memory tracking') + parser.add_argument('--quick', type=int, default=50, help='Quick connections per batch') + parser.add_argument('--lingering', type=int, default=20, help='Lingering connections per batch') + parser.add_argument('--zombie', type=int, default=10, help='Total zombie connections') + parser.add_argument('--batches', type=int, default=5, help='Number of batches') + parser.add_argument('--delay', type=int, default=5, help='Delay between batches (seconds)') + parser.add_argument('--linger', type=int, default=10, help='How long lingering connections stay open (seconds)') + parser.add_argument('--workers', type=int, default=10, help='Max concurrent 
connections') + + args = parser.parse_args() + + # Platform-specific adjustments + if platform.system() == "Windows": + if args.workers > 10: + args.workers = 10 + if args.quick > 30: + args.quick = 30 + + tester = AbacusLeakTester( + server_url=args.url, + process_name=args.process, + quick_connections=args.quick, + lingering_connections=args.lingering, + zombie_connections=args.zombie, + num_batches=args.batches, + delay_between_batches=args.delay, + linger_time=args.linger, + max_workers=args.workers + ) + + try: + tester.run_test() + except KeyboardInterrupt: + print("\n\033[93mTest interrupted by user.\033[0m") + # Clean up even on keyboard interrupt + tester.cleanup_zombie_connections() + finally: + colorama.deinit() + + +if __name__ == '__main__': + main() diff --git a/docs/index.html b/docs/index.html index 9c2eadc..b0da263 100644 --- a/docs/index.html +++ b/docs/index.html @@ -569,7 +569,7 @@

/stats

"expired_keys__since_restart": "130", // number of keys expired since db's last restart "key_misses__since_restart": "205", // number of keys not found since db's last restart "total_keys": 87904, // total number of keys created - "version": "1.3.3", // Abacus's version + "version": "1.4.0", // Abacus's version "shard": "boujee-coorgi", // Handler shard "uptime": "1h23m45s" // shard uptime diff --git a/main.go b/main.go index 0c351d8..608aa0a 100644 --- a/main.go +++ b/main.go @@ -31,7 +31,7 @@ import ( const ( DocsUrl string = "https://jasoncameron.dev/abacus/" - Version string = "1.3.3" + Version string = "1.4.0" ) var ( diff --git a/middleware/sse.go b/middleware/sse.go index 4c2c049..b0334fa 100644 --- a/middleware/sse.go +++ b/middleware/sse.go @@ -1,6 +1,8 @@ package middleware -import "github.com/gin-gonic/gin" +import ( + "github.com/gin-gonic/gin" +) func SSEMiddleware() gin.HandlerFunc { return func(c *gin.Context) { @@ -8,6 +10,7 @@ func SSEMiddleware() gin.HandlerFunc { c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") // Disable proxy buffering c.Next() } } diff --git a/routes.go b/routes.go index 6d2fb8a..a7a9516 100644 --- a/routes.go +++ b/routes.go @@ -3,6 +3,7 @@ package main import ( "context" "errors" + "fmt" "io" "log" "math" @@ -38,95 +39,95 @@ func StreamValueView(c *gin.Context) { c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Initialize client channel with buffer to prevent deadlock - clientChan := make(chan int, 5) // Buffer a few values to prevent blocking + // Initialize client channel with a buffer to prevent blocking + clientChan := make(chan int, 10) - // Create a cancellable context for this client - ctx, cancel := context.WithCancel(c.Request.Context()) - defer cancel() // Ensure context is always canceled + // Create a context that's canceled when the client disconnects + ctx := c.Request.Context() - // Client registration + // Add this client to the event server for this specific key utils.ValueEventServer.NewClients <- utils.KeyClientPair{ Key: dbKey, Client: clientChan, } - // Use a proper mutex-protected flag to track cleanup state - var cleanupOnce sync.Once - cleanup := func() { - cleanupOnce.Do(func() { - log.Printf("Cleaning up client for key %s", dbKey) + // Track if cleanup has been done + var cleanupDone bool + var cleanupMutex sync.Mutex - // Signal that this client is closed + // Ensure client is always removed when handler exits + defer func() { + cleanupMutex.Lock() + if !cleanupDone { + cleanupDone = true + cleanupMutex.Unlock() + + // Signal the event server to remove this client select { - case utils.ValueEventServer.ClosedClients <- utils.KeyClientPair{ - Key: dbKey, - Client: clientChan, - }: - // Successfully sent close signal + case utils.ValueEventServer.ClosedClients <- utils.KeyClientPair{Key: dbKey, Client: clientChan}: + // Successfully sent cleanup signal case <-time.After(500 * time.Millisecond): - log.Printf("Warning: Timed out sending client closure signal for %s", dbKey) + // Timed out waiting to send cleanup signal + log.Printf("Warning: Timed out sending cleanup signal for %s", dbKey) } + } else { + cleanupMutex.Unlock() + } + }() - // Use a separate goroutine to drain any remaining messages - // This prevents blocking the cleanup function - go 
func() { - timeout := time.NewTimer(1 * time.Second) - defer timeout.Stop() - - // Drain any pending messages - for { - select { - case _, ok := <-clientChan: - if !ok { - return // Channel already closed - } - // Discard message - case <-timeout.C: - // Safety timeout - return - } - } - }() - }) - } - - // Ensure cleanup runs when handler exits - defer cleanup() - - // Monitor for client disconnection + // Monitor for client disconnection in a separate goroutine go func() { - select { - case <-ctx.Done(): // Context done = client disconnected or request canceled - cleanup() + <-ctx.Done() // Wait for context cancellation (client disconnected) + + cleanupMutex.Lock() + if !cleanupDone { + cleanupDone = true + cleanupMutex.Unlock() + + log.Printf("Client disconnected for key %s, cleaning up", dbKey) + + // Signal the event server to remove this client + select { + case utils.ValueEventServer.ClosedClients <- utils.KeyClientPair{Key: dbKey, Client: clientChan}: + // Successfully sent cleanup signal + case <-time.After(500 * time.Millisecond): + // Timed out waiting to send cleanup signal + log.Printf("Warning: Timed out sending cleanup signal for %s after disconnect", dbKey) + } + } else { + cleanupMutex.Unlock() } }() // Send initial value initialVal := Client.Get(context.Background(), dbKey).Val() if count, err := strconv.Atoi(initialVal); err == nil { - c.SSEvent("message", map[string]int{"value": count}) + // Keep your exact format + _, err := c.Writer.WriteString(fmt.Sprintf("data: {\"value\":%d}\n\n", count)) + if err != nil { + log.Printf("Error writing to client: %v", err) + return + } c.Writer.Flush() } - // Stream updates with clear error handling + // Stream updates c.Stream(func(w io.Writer) bool { select { case <-ctx.Done(): - log.Printf("Client context done for key %s", dbKey) return false - case count, ok := <-clientChan: if !ok { - log.Printf("Client channel closed for key %s", dbKey) return false } - - // Use SSEvent for consistent formatting - c.SSEvent("message", map[string]int{"value": count}) - + // Keep your exact format + _, err := c.Writer.WriteString(fmt.Sprintf("data: {\"value\":%d}\n\n", count)) + if err != nil { + log.Printf("Error writing to client: %v", err) + return false + } + c.Writer.Flush() return true } }) diff --git a/routes_test.go b/routes_test.go index 33e07b0..c7c7a30 100644 --- a/routes_test.go +++ b/routes_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "os" + "strings" "testing" "time" @@ -482,6 +483,10 @@ func TestStreamValueView(t *testing.T) { w := newMockResponseWriter() req, _ := http.NewRequest("GET", "/stream/test/stream_key", nil) + // Create a new context with cancellation + requestCtx, cancelFunc := context.WithCancel(req.Context()) + req = req.WithContext(requestCtx) + // Channel to signal test completion done := make(chan struct{}) go func() { @@ -489,10 +494,21 @@ func TestStreamValueView(t *testing.T) { r.ServeHTTP(w, req) }() - // Give the stream some time to start - time.Sleep(100 * time.Millisecond) + // Wait for initial response + waitForContains := func(w *mockResponseWriter, expected string, timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if strings.Contains(w.Body.String(), expected) { + return true + } + time.Sleep(10 * time.Millisecond) + } + return false + } - assert.Contains(t, w.Header().Get("Content-Type"), "text/event-stream") + if !waitForContains(w, "data:", 500*time.Millisecond) { + t.Fatal("Initial SSE connection not established in time") + } 
// Hit the key to generate updates hitReq, _ := http.NewRequest("GET", "/hit/test/stream_key", nil) @@ -500,17 +516,28 @@ func TestStreamValueView(t *testing.T) { // Trigger updates hitW := httptest.NewRecorder() r.ServeHTTP(hitW, hitReq) - time.Sleep(50 * time.Millisecond) // Allow the stream to process - assert.Contains(t, w.Body.String(), "data: {\"value\":1}\n\n") - r.ServeHTTP(hitW, hitReq) // Hit it again - time.Sleep(50 * time.Millisecond) // Allow the stream to process - assert.Contains(t, w.Body.String(), "data: {\"value\":2}\n\n") + // Check for value 1 with timeout + if !waitForContains(w, "data: {\"value\":1}\n\n", 500*time.Millisecond) { + t.Fatal("Did not receive first update in time") + } + + r.ServeHTTP(hitW, hitReq) // Hit it again + + // Check for value 2 with timeout + if !waitForContains(w, "data: {\"value\":2}\n\n", 500*time.Millisecond) { + t.Fatal("Did not receive second update in time") + } // Signal the stream to stop + cancelFunc() + + // Wait for goroutine to finish with timeout select { case <-done: - case <-time.After(1 * time.Second): // Ensure test doesn't hang forever + // Test completed successfully + case <-time.After(1 * time.Second): + t.Fatal("Test timed out waiting for stream to close") } }) } diff --git a/stream_test.go b/stream_test.go new file mode 100644 index 0000000..2cd89d8 --- /dev/null +++ b/stream_test.go @@ -0,0 +1,379 @@ +package main + +import ( + "bufio" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/jasonlovesdoggo/abacus/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// MockCloseNotifier implements http.CloseNotifier for testing +type MockResponseRecorder struct { + *httptest.ResponseRecorder + closeNotify chan bool +} + +// NewMockResponseRecorder creates a new response recorder with CloseNotify support +func NewMockResponseRecorder() *MockResponseRecorder { + return &MockResponseRecorder{ + ResponseRecorder: httptest.NewRecorder(), + closeNotify: make(chan bool, 1), + } +} + +// CloseNotify implements http.CloseNotifier +func (m *MockResponseRecorder) CloseNotify() <-chan bool { + return m.closeNotify +} + +// Close simulates a client disconnection +func (m *MockResponseRecorder) Close() { + select { + case m.closeNotify <- true: + // Signal sent + default: + // Channel already has a value or is closed + } +} + +// TestStreamBasicFunctionality tests that the stream endpoint correctly +// sends events when values are updated +func TestStreamBasicFunctionality(t *testing.T) { + gin.SetMode(gin.TestMode) + router := setupTestRouter() + + // Create a counter first + createResp := httptest.NewRecorder() + createReq, _ := http.NewRequest("POST", "/create/test/stream-test", nil) + router.ServeHTTP(createResp, createReq) + assert.Equal(t, http.StatusCreated, createResp.Code) + + // For streaming tests, we need a real HTTP server + server := httptest.NewServer(router) + defer server.Close() + + // Use a real HTTP client to connect to the server + client := &http.Client{ + Timeout: 5 * time.Second, + } + + req, err := http.NewRequest("GET", server.URL+"/stream/test/stream-test", nil) + require.NoError(t, err) + + req.Header.Set("Accept", "text/event-stream") + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Channel to collect received events + events := make(chan string, 10) + done := make(chan struct{}) + + // Process the SSE stream + go func() { + defer close(done) 
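+		// Scan the SSE body line by line; forward each "data:" line to the events channel,
+		// logging and dropping it if the channel stays full for 100ms.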
+ scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data: ") { + select { + case events <- line: + // Event sent + case <-time.After(100 * time.Millisecond): + // Buffer full, drop event + t.Logf("Event buffer full, dropped: %s", line) + } + } + } + if err := scanner.Err(); err != nil { + t.Logf("Scanner error: %v", err) + } + }() + + // Wait for initial value + select { + case event := <-events: + assert.True(t, strings.HasPrefix(event, "data: {\"value\":")) + case <-time.After(1 * time.Second): + t.Fatal("Timeout waiting for initial event") + } + + // Hit the counter to increment its value + hitResp, err := client.Get(server.URL + "/hit/test/stream-test") + require.NoError(t, err) + hitResp.Body.Close() + assert.Equal(t, http.StatusOK, hitResp.StatusCode) + + // Check that we got an update event + select { + case event := <-events: + assert.True(t, strings.HasPrefix(event, "data: {\"value\":")) + + // Extract the value + value := extractValueFromEvent(event) + assert.Equal(t, 1, value) + case <-time.After(1 * time.Second): + t.Fatal("Timeout waiting for update event") + } + + // Close connection + resp.Body.Close() + + // Give some time for cleanup + time.Sleep(500 * time.Millisecond) + + // Verify proper cleanup + clientCount := countClientsForKey("K:test:stream-test") + assert.Equal(t, 0, clientCount, "Client wasn't properly cleaned up after disconnection") +} + +// TestMultipleClients tests multiple clients connecting to the same stream +func TestMultipleClients(t *testing.T) { + gin.SetMode(gin.TestMode) + router := setupTestRouter() + + // Create a counter + createResp := httptest.NewRecorder() + createReq, _ := http.NewRequest("POST", "/create/test/multi-client", nil) + router.ServeHTTP(createResp, createReq) + assert.Equal(t, http.StatusCreated, createResp.Code) + + // Start a real HTTP server + server := httptest.NewServer(router) + defer server.Close() + + // Number of clients to test + numClients := 3 // Reduced from 5 for faster testing + + // Set up client trackers + type clientState struct { + resp *http.Response + events chan string + done chan struct{} + lastValue int + eventCount int + } + + clients := make([]*clientState, numClients) + + // Start all clients + for i := 0; i < numClients; i++ { + // Create client state + clients[i] = &clientState{ + events: make(chan string, 10), + done: make(chan struct{}), + } + + // Create request + req, err := http.NewRequest("GET", server.URL+"/stream/test/multi-client", nil) + require.NoError(t, err) + req.Header.Set("Accept", "text/event-stream") + + // Connect client + client := &http.Client{ + Timeout: 5 * time.Second, + } + resp, err := client.Do(req) + require.NoError(t, err) + clients[i].resp = resp + + // Process events + go func(idx int) { + defer close(clients[idx].done) + scanner := bufio.NewScanner(clients[idx].resp.Body) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data: ") { + select { + case clients[idx].events <- line: + // Event sent + default: + // Buffer full, drop event + } + } + } + }(i) + } + + // Give time for all clients to connect + time.Sleep(300 * time.Millisecond) + + // Verify all clients receive initial value + for i := 0; i < numClients; i++ { + select { + case event := <-clients[i].events: + clients[i].lastValue = extractValueFromEvent(event) + clients[i].eventCount++ + case <-time.After(1 * time.Second): + t.Fatalf("Timeout waiting for client %d initial event", i) + } + } + + // Hit the counter several 
times + for hits := 0; hits < 3; hits++ { + client := &http.Client{} + resp, err := client.Get(server.URL + "/hit/test/multi-client") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Give time for events to propagate + time.Sleep(100 * time.Millisecond) + } + + // Verify all clients received the updates + for i := 0; i < numClients; i++ { + // Drain all events + timeout := time.After(500 * time.Millisecond) + draining := true + + for draining { + select { + case event := <-clients[i].events: + clients[i].lastValue = extractValueFromEvent(event) + clients[i].eventCount++ + case <-timeout: + draining = false + } + } + + // Each client should have received at least 4 events (initial + 3 hits) + assert.GreaterOrEqual(t, clients[i].eventCount, 4, "Client %d didn't receive enough events", i) + assert.Equal(t, 3, clients[i].lastValue, "Client %d has incorrect final value", i) + } + + // Disconnect clients one by one and verify cleanup + for i := 0; i < numClients; i++ { + // Close client connection + clients[i].resp.Body.Close() + + // Give time for cleanup + time.Sleep(200 * time.Millisecond) + + // Verify decreasing client count + clientCount := countClientsForKey("K:test:multi-client") + assert.Equal(t, numClients-(i+1), clientCount, "Client wasn't properly cleaned up after disconnection") + } +} + +// TestStreamConcurrencyStress tests the stream under high concurrency conditions +func TestStreamConcurrencyStress(t *testing.T) { + // Skip in normal testing as this is a long stress test + if testing.Short() { + t.Skip("Skipping stress test in short mode") + } + + gin.SetMode(gin.ReleaseMode) // Reduce logging noise + router := setupTestRouter() + + // Create a counter for stress testing + createResp := httptest.NewRecorder() + createReq, _ := http.NewRequest("POST", "/create/test/stress-test", nil) + router.ServeHTTP(createResp, createReq) + require.Equal(t, http.StatusCreated, createResp.Code) + + // Start a real HTTP server + server := httptest.NewServer(router) + defer server.Close() + + // Test parameters + numClients := 20 // Reduced from 50 for faster testing + clientDuration := 300 * time.Millisecond + + // Start with no clients + initialCount := countClientsForKey("K:test:stress-test") + assert.Equal(t, 0, initialCount) + + // Launch many concurrent clients + var wg sync.WaitGroup + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + // Create client + client := &http.Client{} + + // Create request + req, err := http.NewRequest("GET", server.URL+"/stream/test/stress-test", nil) + if err != nil { + t.Logf("Error creating request: %v", err) + return + } + req.Header.Set("Accept", "text/event-stream") + + // Send request + resp, err := client.Do(req) + if err != nil { + t.Logf("Error connecting: %v", err) + return + } + + // Keep connection open for the duration + time.Sleep(clientDuration) + + // Close connection + resp.Body.Close() + }(i) + + // Stagger client creation slightly + time.Sleep(5 * time.Millisecond) + } + + // Wait for all clients to finish + wg.Wait() + + // Give extra time for any cleanup + time.Sleep(1 * time.Second) + + // Verify all clients were cleaned up + finalCount := countClientsForKey("K:test:stress-test") + assert.Equal(t, 0, finalCount, "Not all clients were cleaned up after stress test") + + // Check we can still connect new clients + client := &http.Client{} + req, err := http.NewRequest("GET", server.URL+"/stream/test/stress-test", nil) + require.NoError(t, err) + 
req.Header.Set("Accept", "text/event-stream") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Give time for connection + time.Sleep(200 * time.Millisecond) + + // Verify new client connected + newCount := countClientsForKey("K:test:stress-test") + assert.Equal(t, 1, newCount, "Failed to connect new client after stress test") + + // Clean up + resp.Body.Close() + time.Sleep(200 * time.Millisecond) +} + +func countClientsForKey(key string) int { + return utils.ValueEventServer.CountClientsForKey(key) +} + +func extractValueFromEvent(event string) int { + // Format is "data: {"value":X}" + jsonStr := strings.TrimPrefix(event, "data: ") + var data struct { + Value int `json:"value"` + } + err := json.Unmarshal([]byte(jsonStr), &data) + if err != nil { + return -1 + } + return data.Value +} diff --git a/utils/sse.go b/utils/sse.go index fb9b72f..b09d155 100644 --- a/utils/sse.go +++ b/utils/sse.go @@ -3,6 +3,7 @@ package utils import ( "log" "sync" + "time" ) type ValueEvent struct { @@ -25,9 +26,10 @@ type KeyClientPair struct { func NewValueEventServer() *ValueEvent { event := &ValueEvent{ - Message: make(chan KeyValue), - NewClients: make(chan KeyClientPair), - ClosedClients: make(chan KeyClientPair), + // Use buffered channels to prevent blocking + Message: make(chan KeyValue, 100), + NewClients: make(chan KeyClientPair, 100), + ClosedClients: make(chan KeyClientPair, 100), TotalClients: make(map[string]map[chan int]bool), } go event.listen() @@ -48,26 +50,85 @@ func (v *ValueEvent) listen() { case closedClient := <-v.ClosedClients: v.Mu.Lock() - delete(v.TotalClients[closedClient.Key], closedClient.Client) - close(closedClient.Client) + if clients, exists := v.TotalClients[closedClient.Key]; exists { + if _, ok := clients[closedClient.Client]; ok { + delete(clients, closedClient.Client) - // Clean up key map if no more clients - if len(v.TotalClients[closedClient.Key]) == 0 { - delete(v.TotalClients, closedClient.Key) + // Close channel safely + close(closedClient.Client) + + log.Printf("Removed client for key %s", closedClient.Key) + + // Clean up key map if no more clients + if len(clients) == 0 { + delete(v.TotalClients, closedClient.Key) + log.Printf("No more clients for key %s, removed key entry", closedClient.Key) + } + } } v.Mu.Unlock() - log.Printf("Removed client for key %s", closedClient.Key) case keyValue := <-v.Message: + // First, get a snapshot of clients under read lock v.Mu.RLock() - for clientChan := range v.TotalClients[keyValue.Key] { - clientChan <- keyValue.Value + clients, exists := v.TotalClients[keyValue.Key] + if !exists || len(clients) == 0 { + v.Mu.RUnlock() + continue + } + + // Create a safe copy of client channels + clientChannels := make([]chan int, 0, len(clients)) + for clientChan := range clients { + clientChannels = append(clientChannels, clientChan) } v.Mu.RUnlock() + + // Send messages without holding the lock + // Track which clients failed to receive + var failedClients []chan int + for _, clientChan := range clientChannels { + select { + case clientChan <- keyValue.Value: + // Message sent successfully + case <-time.After(100 * time.Millisecond): + // Client not responding, mark for removal + failedClients = append(failedClients, clientChan) + } + } + + // Schedule removal of failed clients + for _, failedClient := range failedClients { + select { + case v.ClosedClients <- KeyClientPair{Key: keyValue.Key, Client: failedClient}: + // Client scheduled for removal + default: + // If ClosedClients channel is 
full, try again later + go func(key string, client chan int) { + time.Sleep(200 * time.Millisecond) + select { + case v.ClosedClients <- KeyClientPair{Key: key, Client: client}: + // Success on retry + default: + log.Printf("Failed to remove client for key %s even after retry", key) + } + }(keyValue.Key, failedClient) + } + } } } } +func (v *ValueEvent) CountClientsForKey(key string) int { + v.Mu.RLock() + defer v.Mu.RUnlock() + + if clients, exists := v.TotalClients[key]; exists { + return len(clients) + } + return 0 +} + // Global event server var ValueEventServer *ValueEvent @@ -77,10 +138,12 @@ func init() { // When you want to update a value and notify clients for a specific key func SetStream(dbKey string, newValue int) { - // Broadcast the new value only to clients listening to this specific key - ValueEventServer.Message <- KeyValue{ - Key: dbKey, - Value: newValue, + // Use a non-blocking send with default case to prevent blocking + select { + case ValueEventServer.Message <- KeyValue{Key: dbKey, Value: newValue}: + // Message sent successfully + default: + log.Printf("Warning: Message channel full, update for key %s dropped", dbKey) } } @@ -100,8 +163,11 @@ func CloseStream(dbKey string) { ValueEventServer.Mu.Unlock() // Now close the channels after releasing the lock - // This ensures we're not holding the lock while performing potentially blocking operations for _, ch := range channelsToClose { close(ch) } + + if len(channelsToClose) > 0 { + log.Printf("Closed all streams for key %s (%d clients)", dbKey, len(channelsToClose)) + } } diff --git a/utils/stats.go b/utils/stats.go index 2e5e757..58886bf 100644 --- a/utils/stats.go +++ b/utils/stats.go @@ -150,6 +150,9 @@ func (sm *StatManager) monitorHealth() { // Log that we're skipping save despite high total because buffer is empty log.Printf("Total count high (%d/%d) but buffer is empty. Skipping unnecessary save operation.", snapshot.Total, totalWarningThreshold) + } else { + log.Printf("Buffer (%d/%d) and total count (%d/%d) are within acceptable limits. Skipping save operation.", + len(sm.buffer), batchSize, snapshot.Total, totalWarningThreshold) } } }