a digital entity named phi that roams bsky phi.zzstoatzz.io
2
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 112 lines 4.5 kB view raw
1"""Tests for rate limiting and SSRF protection.""" 2 3import ipaddress 4from unittest.mock import AsyncMock, Mock 5 6from limits import parse as parse_limit 7from limits.storage import MemoryStorage 8from limits.strategies import MovingWindowRateLimiter 9 10 11class TestPerUserRateLimiting: 12 """Test per-user notification rate limiting.""" 13 14 def setup_method(self): 15 self.storage = MemoryStorage() 16 self.limiter = MovingWindowRateLimiter(self.storage) 17 self.limit = parse_limit("3/minute") # low limit for testing 18 19 def test_allows_under_limit(self): 20 for _ in range(3): 21 assert self.limiter.hit(self.limit, "user.bsky.social") 22 23 def test_blocks_over_limit(self): 24 for _ in range(3): 25 self.limiter.hit(self.limit, "user.bsky.social") 26 assert not self.limiter.hit(self.limit, "user.bsky.social") 27 28 def test_independent_per_user(self): 29 for _ in range(3): 30 self.limiter.hit(self.limit, "user-a.bsky.social") 31 # user-a is exhausted 32 assert not self.limiter.hit(self.limit, "user-a.bsky.social") 33 # user-b is unaffected 34 assert self.limiter.hit(self.limit, "user-b.bsky.social") 35 36 37class TestMessageHandlerRateLimiting: 38 """Test that MessageHandler.handle_batch respects per-author rate limits. 39 40 Even though batches now coalesce a poll cycle's notifications into one 41 agent run, rate limiting still applies per-author per-notification — a 42 spammer who chains posts gets each post counted toward their hourly cap. 43 Once the cap is hit, subsequent notifications from that author are filtered 44 out of the batch and the agent run is skipped if nothing remains. 45 """ 46 47 async def test_rate_limited_author_filtered_from_batch(self): 48 from bot.services import message_handler 49 50 handler = Mock() 51 handler.client = Mock() 52 handler.agent = Mock() 53 handler.agent.process_notifications = AsyncMock() 54 handler._build_post_entry = AsyncMock(return_value=None) 55 handler._build_engagement_entry = AsyncMock(return_value=None) 56 handler._build_follow_entry = AsyncMock(return_value=None) 57 handler._maybe_lookup_stranger = AsyncMock(return_value=None) 58 59 def make_notif(): 60 n = Mock() 61 n.reason = "mention" 62 n.author.handle = "spammer.bsky.social" 63 n.uri = "at://example/post/1" 64 return n 65 66 original_limiter = message_handler._limiter 67 original_limit = message_handler._user_limit 68 69 # use a 1/minute limit so the second batch with the same author is blocked 70 test_storage = MemoryStorage() 71 test_limiter = MovingWindowRateLimiter(test_storage) 72 test_limit = parse_limit("1/minute") 73 74 message_handler._limiter = test_limiter 75 message_handler._user_limit = test_limit 76 77 try: 78 # first batch: one notification from the spammer — passes the limiter 79 await message_handler.MessageHandler.handle_batch(handler, [make_notif()]) 80 # _build_post_entry was called (limiter let it through) 81 handler._build_post_entry.assert_called_once() 82 83 handler._build_post_entry.reset_mock() 84 85 # second batch from the same author — filtered out by limiter 86 # nothing was actionable so build/process never get invoked 87 await message_handler.MessageHandler.handle_batch(handler, [make_notif()]) 88 handler._build_post_entry.assert_not_called() 89 handler.agent.process_notifications.assert_not_called() 90 finally: 91 message_handler._limiter = original_limiter 92 message_handler._user_limit = original_limit 93 94 95class TestSSRFProtection: 96 """Test that check_urls blocks private IPs.""" 97 98 def test_private_ips_detected(self): 99 private_ips = ["127.0.0.1", "10.0.0.1", "192.168.1.1", "172.16.0.1", "::1"] 100 for ip_str in private_ips: 101 ip = ipaddress.ip_address(ip_str) 102 assert ip.is_private or ip.is_loopback or ip.is_link_local, ( 103 f"{ip_str} should be blocked" 104 ) 105 106 def test_public_ips_allowed(self): 107 public_ips = ["8.8.8.8", "1.1.1.1", "140.82.121.4"] 108 for ip_str in public_ips: 109 ip = ipaddress.ip_address(ip_str) 110 assert not (ip.is_private or ip.is_loopback or ip.is_link_local), ( 111 f"{ip_str} should be allowed" 112 )