A local-first private AI assistant for everyday use. Runs on-device models with encrypted P2P sync, and supports sharing chats publicly on ATProto.
10
fork

Configure Feed

Select the types of activity you want to include in your feed.

fix: stopped using uv to start dev server, due to flat layout issues

madclaws ad404b17 117becde

+9 -1135
+1 -1
justfile
··· 12 12 cargo test 13 13 14 14 serve: 15 - uv run --project server python -m server.main 15 + server/.venv/bin/python3 -m server.main 16 16 17 17 bundle: 18 18 ./scripts/bundler.sh
-1
server/api.py
··· 61 61 if request.stream: 62 62 result = ({}, "") 63 63 if request.python_code: 64 - create_memory_if_not_exists() 65 64 result = execute_sandboxed_code( 66 65 code=request.python_code, 67 66 allowed_path=_memory_path,
+1 -1
server/backend/mlx_runner.py
··· 249 249 250 250 # Model-specific handling based on known patterns 251 251 # Use reasoning_utils for reasoning model detection and patterns 252 - from .reasoning_utils import ReasoningExtractor 252 + from ..reasoning_utils import ReasoningExtractor 253 253 254 254 if hasattr(self.tokenizer, "name_or_path"): 255 255 name_or_path = str(getattr(self.tokenizer, "name_or_path", "")).lower()
-1128
server/mlx_runner.py
··· 1 - """ 2 - Enhanced MLX model runner with direct API integration. 3 - Provides ollama-like run experience with streaming and interactive chat. 4 - """ 5 - 6 - import sys 7 - import json 8 - import os 9 - import time 10 - from collections.abc import Iterator 11 - from pathlib import Path 12 - from typing import Dict, Optional 13 - 14 - if sys.platform == "darwin": 15 - import mlx.core as mx 16 - else: 17 - mx = None 18 - from mlx_lm import load 19 - from mlx_lm.generate import generate_step 20 - from mlx_lm.sample_utils import make_repetition_penalty, make_sampler 21 - 22 - from ..reasoning_utils import ReasoningExtractor, StreamingReasoningParser 23 - 24 - 25 - def get_model_context_length(model_path: str) -> int: 26 - """Extract max_position_embeddings from model config. 27 - 28 - Args: 29 - model_path: Path to the MLX model directory 30 - 31 - Returns: 32 - Maximum context length for the model (defaults to 4096 if not found) 33 - """ 34 - config_path = os.path.join(model_path, "config.json") 35 - 36 - try: 37 - with open(config_path) as f: 38 - config = json.load(f) 39 - 40 - # Try various common config keys for context length 41 - context_keys = [ 42 - "max_position_embeddings", 43 - "n_positions", 44 - "context_length", 45 - "max_sequence_length", 46 - "seq_len", 47 - ] 48 - 49 - for key in context_keys: 50 - if key in config: 51 - return config[key] 52 - 53 - # If no context length found, return reasonable default 54 - return 4096 55 - 56 - except (FileNotFoundError, json.JSONDecodeError, KeyError): 57 - # Return default if config can't be read 58 - return 4096 59 - 60 - 61 - class MLXRunner: 62 - """Direct MLX model runner with streaming and interactive capabilities.""" 63 - 64 - def __init__( 65 - self, model_path: str, adapter_path: Optional[str] = None, verbose: bool = False 66 - ): 67 - """Initialize the runner with a model. 68 - 69 - Args: 70 - model_path: Path to the MLX model directory 71 - adapter_path: Optional path to LoRA adapter 72 - verbose: Show detailed output 73 - """ 74 - self.model_path = Path(model_path) 75 - self.adapter_path = adapter_path 76 - self.model = None 77 - self.tokenizer = None 78 - self._memory_baseline = None 79 - self._stop_tokens = None # Will be populated from tokenizer 80 - self._message_end_tokens = None # Message-end tokens (e.g., <|end|> for MXFP4) 81 - self._chat_stop_tokens = None # Chat-specific stop tokens 82 - self._context_length = None # Will be populated from model config 83 - self._is_reasoning_model = False # Whether model uses reasoning (MXFP4) 84 - self._reasoning_start = None # Reasoning start marker 85 - self._reasoning_end = None # Reasoning end marker 86 - self._final_start = None # Final answer start marker 87 - self.verbose = verbose 88 - self._model_loaded = False 89 - self._context_entered = False # Prevent nested context usage 90 - 91 - def __enter__(self): 92 - """Context manager entry - loads the model.""" 93 - if self._context_entered: 94 - raise RuntimeError( 95 - "MLXRunner context manager cannot be entered multiple times" 96 - ) 97 - 98 - self._context_entered = True 99 - try: 100 - self.load_model() 101 - return self 102 - except Exception: 103 - # If load_model fails, ensure cleanup happens 104 - self._context_entered = False 105 - self.cleanup() 106 - raise 107 - 108 - def __exit__(self, exc_type, exc_val, exc_tb): 109 - """Context manager exit - cleans up the model.""" 110 - self._context_entered = False 111 - self.cleanup() 112 - return False # Don't suppress exceptions 113 - 114 - def load_model(self): 115 - """Load the MLX model and tokenizer.""" 116 - if self._model_loaded: 117 - if self.verbose: 118 - print("Model already loaded, skipping...") 119 - return 120 - 121 - if self.verbose: 122 - print(f"Loading model from {self.model_path}...") 123 - start_time = time.time() 124 - 125 - # Capture baseline memory before loading 126 - try: 127 - mx.clear_cache() 128 - except Exception: 129 - pass # Continue even if cache clear fails 130 - self._memory_baseline = mx.get_active_memory() / 1024**3 131 - 132 - try: 133 - # Load model and tokenizer 134 - self.model, self.tokenizer = load( 135 - str(self.model_path), adapter_path=self.adapter_path 136 - ) 137 - 138 - load_time = time.time() - start_time 139 - current_memory = mx.get_active_memory() / 1024**3 140 - model_memory = current_memory - self._memory_baseline 141 - 142 - if self.verbose: 143 - print(f"Model loaded in {load_time:.1f}s") 144 - print( 145 - f"Memory: {model_memory:.1f}GB model, {current_memory:.1f}GB total" 146 - ) 147 - 148 - # Extract stop tokens from tokenizer 149 - self._extract_stop_tokens() 150 - 151 - # Extract context length from model config 152 - self._context_length = get_model_context_length(str(self.model_path)) 153 - 154 - if self.verbose: 155 - print(f"Model context length: {self._context_length} tokens") 156 - 157 - self._model_loaded = True 158 - 159 - except Exception as e: 160 - # Ensure partial state is cleaned up on failure 161 - self.model = None 162 - self.tokenizer = None 163 - self._stop_tokens = None 164 - self._model_loaded = False 165 - # Clear any memory that might have been allocated 166 - mx.clear_cache() 167 - raise RuntimeError( 168 - f"Failed to load model from {self.model_path}: {e}" 169 - ) from e 170 - 171 - def _extract_stop_tokens(self): 172 - """Extract stop tokens from the tokenizer dynamically. 173 - 174 - This method identifies ALL tokens that should stop generation: 175 - 1. Official EOS token from tokenizer config 176 - 2. Message-end tokens from training (e.g., <|end|> for MXFP4) 177 - 3. Common stop tokens across models 178 - """ 179 - self._stop_tokens = set() 180 - self._message_end_tokens = ( 181 - set() 182 - ) # Tokens that end messages but not conversations 183 - 184 - # Primary source: eos_token 185 - eos_token = getattr(self.tokenizer, "eos_token", None) 186 - if eos_token: 187 - self._stop_tokens.add(eos_token) 188 - 189 - # Also check pad_token if it's different from eos_token 190 - pad_token = getattr(self.tokenizer, "pad_token", None) 191 - if pad_token and pad_token != eos_token: 192 - self._stop_tokens.add(pad_token) 193 - 194 - # Check additional_special_tokens 195 - if hasattr(self.tokenizer, "additional_special_tokens"): 196 - for token in self.tokenizer.additional_special_tokens: 197 - if token and isinstance(token, str): 198 - # Only add tokens that look like stop/end tokens 199 - if any( 200 - keyword in token.lower() for keyword in ["end", "stop", "eot"] 201 - ): 202 - self._stop_tokens.add(token) 203 - 204 - # MLX-LM 0.27.0+: Extract tokens from added_tokens_decoder (comprehensive source) 205 - if hasattr(self.tokenizer, "added_tokens_decoder"): 206 - for _token_id, token_info in self.tokenizer.added_tokens_decoder.items(): 207 - if isinstance(token_info, dict) and "content" in token_info: 208 - token_content = token_info["content"] 209 - if token_content and isinstance(token_content, str): 210 - token_lower = token_content.lower() 211 - 212 - # NOTE: <|end|> is NOT a stop token for MXFP4 models! 213 - # It's a separator between reasoning and final answer 214 - if token_content == "<|end|>": 215 - self._message_end_tokens.add(token_content) 216 - # Do NOT add as stop token - let model continue to final answer 217 - 218 - # Look for tokens that could be end/stop tokens 219 - # Expanded patterns for MLX-LM 0.27.0 token varieties 220 - # EXCLUDE <|end|> for MXFP4 models as it's a reasoning separator 221 - end_patterns = [ 222 - "stop", 223 - "eot", 224 - "return", 225 - "finish", 226 - "done", 227 - "im_end", 228 - ] 229 - if any(pattern in token_lower for pattern in end_patterns): 230 - # Decide if it's a message-end or conversation-end token 231 - if "im_end" in token_lower: 232 - self._message_end_tokens.add(token_content) 233 - self._stop_tokens.add(token_content) 234 - # Special handling for 'end' pattern - more selective 235 - elif "end" in token_lower and token_content != "<|end|>": 236 - # Only add non-<|end|> tokens with 'end' in them 237 - self._stop_tokens.add(token_content) 238 - 239 - # Special case: control tokens in |..| format 240 - elif token_content.startswith("<|") and token_content.endswith( 241 - "|>" 242 - ): 243 - # Be inclusive with control tokens that might stop generation 244 - if any( 245 - pattern in token_lower 246 - for pattern in ["end", "return", "stop", "finish"] 247 - ): 248 - self._stop_tokens.add(token_content) 249 - 250 - # Model-specific handling based on known patterns 251 - # Use reasoning_utils for reasoning model detection and patterns 252 - from .reasoning_utils import ReasoningExtractor 253 - 254 - if hasattr(self.tokenizer, "name_or_path"): 255 - name_or_path = str(getattr(self.tokenizer, "name_or_path", "")).lower() 256 - model_type = ReasoningExtractor.detect_model_type(name_or_path) 257 - 258 - if model_type: 259 - # This is a reasoning model 260 - self._is_reasoning_model = True 261 - 262 - # Get patterns from reasoning_utils 263 - if model_type in ReasoningExtractor.PATTERNS: 264 - markers = ReasoningExtractor.PATTERNS[model_type]["markers"] 265 - self._reasoning_start = markers.get("reasoning_start") 266 - self._reasoning_end = markers.get("reasoning_end") 267 - self._final_start = markers.get("final_marker") 268 - 269 - # For reasoning models, remove reasoning_end from stop tokens 270 - if self._reasoning_end: 271 - self._stop_tokens.discard(self._reasoning_end) 272 - 273 - # Add proper stop token for this model type 274 - if model_type == "gpt-oss": 275 - if "<|return|>" not in self._stop_tokens: 276 - self._stop_tokens.add("<|return|>") 277 - else: 278 - self._is_reasoning_model = False 279 - else: 280 - self._is_reasoning_model = False 281 - 282 - # Add common stop tokens that might not be in special tokens 283 - common_stop_tokens = {"</s>", "<|endoftext|>", "<|im_end|>", "<|eot_id|>"} 284 - 285 - # Add chat-specific stop tokens to prevent model self-conversations 286 - # Based on our _format_conversation() format: "Human:" and "Assistant:" 287 - # Also include "You:" as models might use UI-visible format 288 - # Include single-letter variations (H:, A:, Y:) that some models use 289 - chat_stop_tokens = { 290 - "\nHuman:", 291 - "\nAssistant:", 292 - "\nYou:", 293 - "\n\nHuman:", 294 - "\n\nAssistant:", 295 - "\n\nYou:", 296 - "\nH:", 297 - "\nA:", 298 - "\nY:", # Single-letter variations 299 - "\n\nH:", 300 - "\n\nA:", 301 - "\n\nY:", 302 - } 303 - 304 - # Add common stop tokens only if they decode to themselves (i.e., they're single tokens) 305 - for token in common_stop_tokens: 306 - try: 307 - # Try to encode and decode to verify it's a real single token 308 - ids = self.tokenizer.encode(token, add_special_tokens=False) 309 - if ids and len(ids) == 1: # Single token ID means it's a special token 310 - decoded = self.tokenizer.decode(ids) 311 - if decoded == token: 312 - self._stop_tokens.add(token) 313 - except: 314 - pass 315 - 316 - # Store chat stop tokens separately - only used in interactive chat mode 317 - # This prevents stopping mid-story when user asks for dialogues 318 - self._chat_stop_tokens = list(chat_stop_tokens) 319 - 320 - # Remove any None values 321 - self._stop_tokens.discard(None) 322 - self._message_end_tokens.discard(None) 323 - 324 - # Convert to list for easier use 325 - self._stop_tokens = list(self._stop_tokens) 326 - self._message_end_tokens = list(self._message_end_tokens) 327 - 328 - if self.verbose: 329 - if self._stop_tokens: 330 - print(f"Stop tokens: {self._stop_tokens}") 331 - if self._message_end_tokens: 332 - print(f"Message end tokens: {self._message_end_tokens}") 333 - 334 - def cleanup(self): 335 - """Clean up model resources and clear GPU memory. 336 - 337 - This method is safe to call multiple times and handles partial state cleanup. 338 - """ 339 - if self.verbose and self._model_loaded: 340 - memory_before = mx.get_active_memory() / 1024**3 341 - print(f"Cleaning up model (memory before: {memory_before:.1f}GB)...") 342 - 343 - # Always clean up, even if model wasn't fully loaded 344 - self.model = None 345 - self.tokenizer = None 346 - self._stop_tokens = None 347 - self._message_end_tokens = None 348 - self._chat_stop_tokens = None 349 - self._context_length = None 350 - self._is_reasoning_model = False 351 - self._reasoning_start = None 352 - self._reasoning_end = None 353 - self._final_start = None 354 - self._model_loaded = False 355 - 356 - # Force garbage collection and clear MLX cache 357 - import gc 358 - 359 - gc.collect() 360 - try: 361 - mx.clear_cache() 362 - except Exception: 363 - pass # Continue cleanup even if cache clear fails 364 - 365 - if self.verbose: 366 - memory_after = mx.get_active_memory() / 1024**3 367 - if "memory_before" in locals(): 368 - memory_freed = memory_before - memory_after 369 - print( 370 - f"Cleanup complete (memory after: {memory_after:.1f}GB, freed: {memory_freed:.1f}GB)" 371 - ) 372 - else: 373 - print(f"Cleanup complete (memory after: {memory_after:.1f}GB)") 374 - 375 - def get_effective_max_tokens( 376 - self, requested_tokens: Optional[int], interactive: bool = False 377 - ) -> int: 378 - """Get effective max tokens based on model context and usage mode. 379 - 380 - Args: 381 - requested_tokens: The requested max tokens (None if user didn't specify --max-tokens) 382 - interactive: True if this is interactive mode (gets full context length) 383 - 384 - Returns: 385 - Effective max tokens to use 386 - """ 387 - if not self._context_length: 388 - # Fallback when context length is unknown 389 - fallback = 4096 if interactive else 2048 390 - if self.verbose: 391 - if requested_tokens is None: 392 - print( 393 - f"[WARNING] Model context length unknown, using fallback: {fallback} tokens" 394 - ) 395 - else: 396 - print( 397 - f"[WARNING] Model context length unknown, using user specified: {requested_tokens} tokens" 398 - ) 399 - return requested_tokens if requested_tokens is not None else fallback 400 - 401 - if interactive: 402 - if requested_tokens is None: 403 - # User didn't specify --max-tokens: use full model context 404 - return self._context_length 405 - else: 406 - # User specified --max-tokens explicitly: respect their choice but cap at context 407 - return min(requested_tokens, self._context_length) 408 - else: 409 - # Server/batch mode uses half context length for DoS protection 410 - server_limit = self._context_length // 2 411 - return min(requested_tokens or server_limit, server_limit) 412 - 413 - def generate_streaming( 414 - self, 415 - prompt: str, 416 - max_tokens: int = 500, 417 - temperature: float = 0.7, 418 - top_p: float = 0.9, 419 - repetition_penalty: float = 1.1, 420 - repetition_context_size: int = 20, 421 - use_chat_template: bool = True, 422 - use_chat_stop_tokens: bool = False, 423 - interactive: bool = False, 424 - hide_reasoning: bool = False, 425 - ) -> Iterator[str]: 426 - """Generate text with streaming output. 427 - 428 - Args: 429 - prompt: Input prompt 430 - max_tokens: Maximum tokens to generate 431 - temperature: Sampling temperature 432 - top_p: Top-p sampling parameter 433 - repetition_penalty: Penalty for repeated tokens 434 - repetition_context_size: Context size for repetition penalty 435 - use_chat_template: Apply tokenizer's chat template if available 436 - use_chat_stop_tokens: Include chat turn markers as stop tokens (for interactive mode) 437 - interactive: True if this is interactive mode (affects token limits) 438 - 439 - Yields: 440 - Generated tokens as they are produced 441 - """ 442 - if not self.model or not self.tokenizer: 443 - raise RuntimeError("Model not loaded. Call load_model() first.") 444 - 445 - # Initialize reasoning parser if this is a reasoning model 446 - reasoning_parser = None 447 - if self._is_reasoning_model: 448 - model_type = ReasoningExtractor.detect_model_type( 449 - getattr(self.tokenizer, "name_or_path", "") or "" 450 - ) 451 - reasoning_parser = StreamingReasoningParser( 452 - model_type, hide_reasoning=hide_reasoning 453 - ) 454 - 455 - # Apply context-aware token limits 456 - effective_max_tokens = self.get_effective_max_tokens(max_tokens, interactive) 457 - 458 - # Apply chat template if available and requested 459 - if ( 460 - use_chat_template 461 - and hasattr(self.tokenizer, "chat_template") 462 - and self.tokenizer.chat_template 463 - ): 464 - messages = [{"role": "user", "content": prompt}] 465 - formatted_prompt = self.tokenizer.apply_chat_template( 466 - messages, tokenize=False, add_generation_prompt=True 467 - ) 468 - else: 469 - formatted_prompt = prompt 470 - 471 - # Tokenize the prompt 472 - prompt_tokens = self.tokenizer.encode(formatted_prompt) 473 - prompt_array = mx.array(prompt_tokens) 474 - 475 - # Track generation metrics 476 - start_time = time.time() 477 - tokens_generated = 0 478 - 479 - # Create sampler with our parameters 480 - sampler = make_sampler(temp=temperature, top_p=top_p) 481 - 482 - # Create repetition penalty processor if needed 483 - logits_processors = [] 484 - if repetition_penalty > 1.0: 485 - logits_processors.append( 486 - make_repetition_penalty(repetition_penalty, repetition_context_size) 487 - ) 488 - 489 - # Generate tokens one by one for streaming 490 - generator = generate_step( 491 - prompt=prompt_array, 492 - model=self.model, 493 - max_tokens=effective_max_tokens, 494 - sampler=sampler, 495 - logits_processors=logits_processors if logits_processors else None, 496 - ) 497 - 498 - # Collect tokens and yield text 499 - generated_tokens = [] 500 - previous_decoded = "" 501 - accumulated_response = "" # Track full response for stop token detection 502 - 503 - # Keep a sliding window of recent tokens for context 504 - context_window = 10 # Decode last N tokens for proper spacing 505 - 506 - for token, _ in generator: 507 - # Token might be an array or an int 508 - token_id = token.item() if hasattr(token, "item") else token 509 - generated_tokens.append(token_id) 510 - 511 - # Use a sliding window approach for efficiency 512 - start_idx = max(0, len(generated_tokens) - context_window) 513 - window_tokens = generated_tokens[start_idx:] 514 - 515 - # Decode the window 516 - window_text = self.tokenizer.decode(window_tokens) 517 - 518 - # Figure out what's new 519 - if start_idx == 0: 520 - # We're still within the context window 521 - if window_text.startswith(previous_decoded): 522 - new_text = window_text[len(previous_decoded) :] 523 - else: 524 - new_text = self.tokenizer.decode([token_id]) 525 - previous_decoded = window_text 526 - else: 527 - # We're beyond the context window, just decode the last token with context 528 - # This is approximate but should preserve spaces 529 - new_text = self.tokenizer.decode(window_tokens) 530 - if len(window_tokens) > 1: 531 - prefix = self.tokenizer.decode(window_tokens[:-1]) 532 - if new_text.startswith(prefix): 533 - new_text = new_text[len(prefix) :] 534 - else: 535 - new_text = self.tokenizer.decode([token_id]) 536 - 537 - if new_text: 538 - # Update accumulated response for stop token checking 539 - accumulated_response += new_text 540 - 541 - # Filter out stop tokens with priority: native first, then chat fallback 542 - # Check native stop tokens FIRST in accumulated response (highest priority) 543 - native_stop_tokens = self._stop_tokens if self._stop_tokens else [] 544 - for stop_token in native_stop_tokens: 545 - if stop_token in accumulated_response: 546 - # Find the stop token position and yield everything before it 547 - stop_pos = accumulated_response.find(stop_token) 548 - # Calculate what text came before the stop token 549 - text_before_stop = accumulated_response[:stop_pos] 550 - # Calculate how much of that is new (not previously yielded) 551 - previously_yielded_length = len(accumulated_response) - len( 552 - new_text 553 - ) 554 - if len(text_before_stop) > previously_yielded_length: 555 - # Yield only the new part before stop token 556 - new_part_before_stop = text_before_stop[ 557 - previously_yielded_length: 558 - ] 559 - if new_part_before_stop: 560 - if reasoning_parser: 561 - # Process through reasoning parser for formatting 562 - for ( 563 - formatted_token 564 - ) in reasoning_parser.process_token( 565 - new_part_before_stop 566 - ): 567 - yield formatted_token 568 - else: 569 - yield new_part_before_stop 570 - return # Stop generation without yielding stop token 571 - 572 - # Only check chat stop tokens if no native stop token found (fallback) 573 - if use_chat_stop_tokens and self._chat_stop_tokens: 574 - for stop_token in self._chat_stop_tokens: 575 - if stop_token in accumulated_response: 576 - # Find the stop token position and yield everything before it 577 - stop_pos = accumulated_response.find(stop_token) 578 - # Calculate what text came before the stop token 579 - text_before_stop = accumulated_response[:stop_pos] 580 - # Calculate how much of that is new (not previously yielded) 581 - previously_yielded_length = len(accumulated_response) - len( 582 - new_text 583 - ) 584 - if len(text_before_stop) > previously_yielded_length: 585 - # Yield only the new part before stop token 586 - new_part_before_stop = text_before_stop[ 587 - previously_yielded_length: 588 - ] 589 - if new_part_before_stop: 590 - if reasoning_parser: 591 - # Process through reasoning parser for formatting 592 - for ( 593 - formatted_token 594 - ) in reasoning_parser.process_token( 595 - new_part_before_stop 596 - ): 597 - yield formatted_token 598 - else: 599 - yield new_part_before_stop 600 - return # Stop generation without yielding stop token 601 - 602 - # No stop token found, process the new text 603 - if reasoning_parser: 604 - # Process through reasoning parser for formatting 605 - for formatted_token in reasoning_parser.process_token(new_text): 606 - yield formatted_token 607 - else: 608 - # Normal streaming for non-reasoning models 609 - yield new_text 610 - tokens_generated += 1 611 - 612 - # Check for EOS token - don't yield it 613 - if token_id == self.tokenizer.eos_token_id: 614 - break 615 - 616 - # Finalize reasoning parser if used 617 - if reasoning_parser: 618 - yield from reasoning_parser.finalize() 619 - 620 - # Print generation statistics if verbose 621 - if self.verbose: 622 - generation_time = time.time() - start_time 623 - tokens_per_second = ( 624 - tokens_generated / generation_time if generation_time > 0 else 0 625 - ) 626 - print( 627 - f"\n\nGenerated {tokens_generated} tokens in {generation_time:.1f}s ({tokens_per_second:.1f} tokens/s)" 628 - ) 629 - 630 - def generate_batch( 631 - self, 632 - prompt: str, 633 - max_tokens: int = 500, 634 - temperature: float = 0.7, 635 - top_p: float = 0.9, 636 - repetition_penalty: float = 1.1, 637 - repetition_context_size: int = 20, 638 - use_chat_template: bool = True, 639 - interactive: bool = False, 640 - ) -> str: 641 - """Generate text in batch mode (non-streaming). 642 - 643 - Args: 644 - prompt: Input prompt 645 - max_tokens: Maximum tokens to generate 646 - temperature: Sampling temperature 647 - top_p: Top-p sampling parameter 648 - repetition_penalty: Penalty for repeated tokens 649 - repetition_context_size: Context size for repetition penalty 650 - use_chat_template: Apply tokenizer's chat template if available 651 - interactive: True if this is interactive mode (affects token limits) 652 - 653 - Returns: 654 - Generated text 655 - """ 656 - if not self.model or not self.tokenizer: 657 - raise RuntimeError("Model not loaded. Call load_model() first.") 658 - 659 - # Apply context-aware token limits 660 - effective_max_tokens = self.get_effective_max_tokens(max_tokens, interactive) 661 - 662 - # Apply chat template if available and requested 663 - if ( 664 - use_chat_template 665 - and hasattr(self.tokenizer, "chat_template") 666 - and self.tokenizer.chat_template 667 - ): 668 - messages = [{"role": "user", "content": prompt}] 669 - formatted_prompt = self.tokenizer.apply_chat_template( 670 - messages, tokenize=False, add_generation_prompt=True 671 - ) 672 - else: 673 - formatted_prompt = prompt 674 - 675 - start_time = time.time() 676 - 677 - # Tokenize the prompt 678 - prompt_tokens = self.tokenizer.encode(formatted_prompt) 679 - prompt_array = mx.array(prompt_tokens) 680 - 681 - # Create sampler with our parameters 682 - sampler = make_sampler(temp=temperature, top_p=top_p) 683 - 684 - # Create repetition penalty processor if needed 685 - logits_processors = [] 686 - if repetition_penalty > 1.0: 687 - logits_processors.append( 688 - make_repetition_penalty(repetition_penalty, repetition_context_size) 689 - ) 690 - 691 - # Generate all tokens at once 692 - generated_tokens = [] 693 - all_tokens = list(prompt_tokens) # Keep prompt for proper decoding 694 - 695 - generator = generate_step( 696 - prompt=prompt_array, 697 - model=self.model, 698 - max_tokens=effective_max_tokens, 699 - sampler=sampler, 700 - logits_processors=logits_processors if logits_processors else None, 701 - ) 702 - 703 - for token, _ in generator: 704 - # Token might be an array or an int 705 - token_id = token.item() if hasattr(token, "item") else token 706 - generated_tokens.append(token_id) 707 - all_tokens.append(token_id) 708 - 709 - # Check for EOS token - don't yield it 710 - if token_id == self.tokenizer.eos_token_id: 711 - break 712 - 713 - # Decode all tokens together for proper spacing 714 - full_response = self.tokenizer.decode(all_tokens) 715 - 716 - # Remove the prompt part 717 - if full_response.startswith(formatted_prompt): 718 - response = full_response[len(formatted_prompt) :] 719 - else: 720 - # Fallback: just decode generated tokens 721 - response = self.tokenizer.decode(generated_tokens) 722 - 723 - # Apply end-token filtering (same logic as streaming mode for Issue #20) 724 - response = self._filter_end_tokens_from_response( 725 - response, use_chat_stop_tokens=False 726 - ) 727 - 728 - # Format reasoning models output 729 - response = self._format_reasoning_response(response) 730 - 731 - generation_time = time.time() - start_time 732 - 733 - # Count tokens for statistics 734 - if self.verbose: 735 - tokens_generated = len(generated_tokens) 736 - tokens_per_second = ( 737 - tokens_generated / generation_time if generation_time > 0 else 0 738 - ) 739 - print( 740 - f"\nGenerated {tokens_generated} tokens in {generation_time:.1f}s ({tokens_per_second:.1f} tokens/s)" 741 - ) 742 - 743 - return response 744 - 745 - def interactive_chat( 746 - self, 747 - system_prompt: Optional[str] = None, 748 - max_tokens: int = 500, 749 - temperature: float = 0.7, 750 - top_p: float = 0.9, 751 - repetition_penalty: float = 1.1, 752 - use_chat_template: bool = True, 753 - ): 754 - """Run an interactive chat session. 755 - 756 - Args: 757 - system_prompt: Optional system prompt to prepend 758 - max_tokens: Maximum tokens per response 759 - temperature: Sampling temperature 760 - top_p: Top-p sampling parameter 761 - repetition_penalty: Penalty for repeated tokens 762 - use_chat_template: Use tokenizer's chat template if available 763 - """ 764 - print("Starting interactive chat. Type 'exit' or 'quit' to end.\n") 765 - 766 - conversation_history = [] 767 - if system_prompt: 768 - conversation_history.append({"role": "system", "content": system_prompt}) 769 - 770 - while True: 771 - try: 772 - # Get user input 773 - user_input = input("You: ").strip() 774 - 775 - if user_input.lower() in ["exit", "quit", "q"]: 776 - print("\nGoodbye!") 777 - break 778 - 779 - if not user_input: 780 - continue 781 - 782 - # Add user message to history 783 - conversation_history.append({"role": "user", "content": user_input}) 784 - 785 - # Format conversation for the model using chat template if available 786 - prompt = self._format_conversation( 787 - conversation_history, use_chat_template=use_chat_template 788 - ) 789 - 790 - # Generate response with streaming 791 - print("\nAssistant: ", end="", flush=True) 792 - 793 - response_tokens = [] 794 - for token in self.generate_streaming( 795 - prompt=prompt, 796 - max_tokens=max_tokens, 797 - temperature=temperature, 798 - top_p=top_p, 799 - repetition_penalty=repetition_penalty, 800 - use_chat_template=False, # Already applied in _format_conversation 801 - use_chat_stop_tokens=True, # Enable chat stop tokens in interactive mode 802 - interactive=True, # Enable full context length for interactive mode 803 - ): 804 - # Stream all tokens directly (already formatted by generate_streaming) 805 - print(token, end="", flush=True) 806 - response_tokens.append(token) 807 - 808 - # Add assistant response to history 809 - assistant_response = "".join(response_tokens).strip() 810 - conversation_history.append( 811 - {"role": "assistant", "content": assistant_response} 812 - ) 813 - 814 - print() # New line after response 815 - 816 - except KeyboardInterrupt: 817 - print("\n\nChat interrupted. Goodbye!") 818 - break 819 - except Exception as e: 820 - print(f"\n[ERROR] {e}") 821 - continue 822 - 823 - def _format_conversation( 824 - self, messages: list, use_chat_template: bool = True 825 - ) -> str: 826 - """Format conversation history into a prompt. 827 - 828 - Uses the tokenizer's chat template if available, otherwise falls back 829 - to the legacy Human:/Assistant: format for compatibility. 830 - 831 - Args: 832 - messages: List of message dictionaries with 'role' and 'content' 833 - use_chat_template: Whether to use chat template if available 834 - 835 - Returns: 836 - Formatted conversation string 837 - """ 838 - # Try to use native chat template if available 839 - if ( 840 - use_chat_template 841 - and hasattr(self.tokenizer, "chat_template") 842 - and self.tokenizer.chat_template 843 - ): 844 - try: 845 - # Apply the tokenizer's chat template 846 - formatted_prompt = self.tokenizer.apply_chat_template( 847 - messages, tokenize=False, add_generation_prompt=True 848 - ) 849 - return formatted_prompt 850 - except Exception as e: 851 - # If chat template fails, fall back to legacy format 852 - if self.verbose: 853 - print(f"[WARNING] Chat template failed, using legacy format: {e}") 854 - 855 - # Legacy format fallback for compatibility 856 - return self._legacy_format_conversation(messages) 857 - 858 - def _legacy_format_conversation(self, messages: list) -> str: 859 - """Legacy conversation formatting for backward compatibility. 860 - 861 - This format was used in earlier versions and remains as a fallback 862 - for models without chat templates. 863 - """ 864 - formatted = [] 865 - 866 - for message in messages: 867 - role = message["role"] 868 - content = message["content"] 869 - 870 - if role == "system": 871 - formatted.append(f"System: {content}") 872 - elif role == "user": 873 - formatted.append(f"Human: {content}") 874 - elif role == "assistant": 875 - formatted.append(f"Assistant: {content}") 876 - 877 - # Add prompt for next assistant response 878 - formatted.append("Assistant:") 879 - 880 - return "\n\n".join(formatted) 881 - 882 - def get_memory_usage(self) -> Dict[str, float]: 883 - """Get current memory usage statistics. 884 - 885 - Returns: 886 - Dictionary with memory statistics in GB 887 - """ 888 - try: 889 - current_memory = mx.get_active_memory() / 1024**3 890 - peak_memory = mx.get_peak_memory() / 1024**3 891 - except Exception: 892 - # Return zeros if memory stats unavailable 893 - current_memory = 0.0 894 - peak_memory = 0.0 895 - 896 - return { 897 - "current_gb": current_memory, 898 - "peak_gb": peak_memory, 899 - "model_gb": ( 900 - current_memory - self._memory_baseline if self._memory_baseline else 0 901 - ), 902 - } 903 - 904 - def _format_reasoning_response(self, response: str) -> str: 905 - """Format response from reasoning models for better readability. 906 - 907 - For MXFP4 models that generate reasoning followed by final answer, 908 - format it nicely for display. 909 - """ 910 - if not self._is_reasoning_model: 911 - return response 912 - 913 - # Check if response contains reasoning markers 914 - if self._reasoning_start in response and self._final_start in response: 915 - # Extract reasoning and final parts 916 - try: 917 - # Split on the reasoning start 918 - before_reasoning, after_start = response.split(self._reasoning_start, 1) 919 - 920 - # Find the reasoning content (until <|end|>) 921 - if self._reasoning_end in after_start: 922 - reasoning_content, after_reasoning = after_start.split( 923 - self._reasoning_end, 1 924 - ) 925 - 926 - # Find the final answer 927 - if self._final_start in after_reasoning: 928 - # Extract everything after final marker 929 - final_parts = after_reasoning.split(self._final_start, 1) 930 - if len(final_parts) > 1: 931 - # Remove the <|channel|>final<|message|> marker 932 - final_answer = final_parts[1].replace( 933 - "<|channel|>final<|message|>", "", 1 934 - ) 935 - 936 - # Format with clear markers for parsing but minimal visual impact 937 - formatted = [] 938 - formatted.append("\n**[Reasoning]**\n") 939 - formatted.append(reasoning_content.strip()) 940 - formatted.append("\n\n---\n\n**[Answer]**\n") 941 - formatted.append(final_answer.strip()) 942 - 943 - return "\n".join(formatted) 944 - except Exception: 945 - # If parsing fails, return original 946 - pass 947 - 948 - # Fallback: just clean up the control tokens 949 - cleaned = response 950 - for marker in [ 951 - "<|channel|>analysis<|message|>", 952 - "<|end|>", 953 - "<|start|>assistant", 954 - "<|channel|>final<|message|>", 955 - "<|return|>", 956 - ]: 957 - cleaned = cleaned.replace(marker, "") 958 - 959 - return cleaned.strip() 960 - 961 - def _filter_end_tokens_from_response( 962 - self, response: str, use_chat_stop_tokens: bool = False 963 - ) -> str: 964 - """Filter end tokens from a complete response (batch mode). 965 - 966 - This method applies the same filtering logic as the streaming mode 967 - to ensure consistent behavior between streaming and non-streaming. 968 - 969 - Args: 970 - response: The complete generated response 971 - use_chat_stop_tokens: Whether to apply chat stop tokens 972 - 973 - Returns: 974 - Response with end tokens filtered out 975 - """ 976 - # Apply native stop token filtering FIRST (highest priority) 977 - native_stop_tokens = self._stop_tokens if self._stop_tokens else [] 978 - for stop_token in native_stop_tokens: 979 - if stop_token in response: 980 - # Find the stop token position and return everything before it 981 - stop_pos = response.find(stop_token) 982 - filtered_response = response[:stop_pos].rstrip() 983 - if self.verbose: 984 - print( 985 - f"[DEBUG] Filtered stop token '{stop_token}' at position {stop_pos}" 986 - ) 987 - return filtered_response 988 - 989 - # Only check chat stop tokens if no native stop token found (fallback) 990 - if use_chat_stop_tokens and self._chat_stop_tokens: 991 - for stop_token in self._chat_stop_tokens: 992 - if stop_token in response: 993 - # Find the stop token position and return everything before it 994 - stop_pos = response.find(stop_token) 995 - return response[:stop_pos] 996 - 997 - # No stop tokens found, return original response 998 - return response 999 - 1000 - 1001 - def get_gpu_status() -> Dict[str, float]: 1002 - """Independent GPU status check - usable from anywhere. 1003 - 1004 - Returns: 1005 - Dictionary with GPU memory statistics in GB 1006 - """ 1007 - return { 1008 - "active_memory_gb": mx.get_active_memory() / 1024**3, 1009 - "peak_memory_gb": mx.get_peak_memory() / 1024**3, 1010 - } 1011 - 1012 - 1013 - def check_memory_available(required_gb: float) -> bool: 1014 - """Pre-flight check before model loading. 1015 - 1016 - Args: 1017 - required_gb: Required memory in GB 1018 - 1019 - Returns: 1020 - True if memory is likely available (conservative estimate) 1021 - """ 1022 - current_memory = mx.get_active_memory() / 1024**3 1023 - 1024 - # Conservative estimate: assume system has at least 8GB unified memory 1025 - # and we should leave some headroom (2GB) for system processes 1026 - estimated_total = 8.0 # This could be improved by detecting actual system memory 1027 - available = estimated_total - current_memory - 2.0 # 2GB headroom 1028 - 1029 - return available >= required_gb 1030 - 1031 - 1032 - def run_model_enhanced( 1033 - model_path: str, 1034 - prompt: Optional[str] = None, 1035 - interactive: bool = False, 1036 - max_tokens: int = 500, 1037 - temperature: float = 0.7, 1038 - top_p: float = 0.9, 1039 - repetition_penalty: float = 1.1, 1040 - stream: bool = True, 1041 - use_chat_template: bool = True, 1042 - hide_reasoning: bool = False, 1043 - verbose: bool = False, 1044 - ) -> Optional[str]: 1045 - """Enhanced run function with direct MLX integration. 1046 - 1047 - Uses context manager pattern for automatic resource cleanup. 1048 - 1049 - Args: 1050 - model_path: Path to the MLX model 1051 - prompt: Input prompt (if None, enters interactive mode) 1052 - interactive: Force interactive mode 1053 - max_tokens: Maximum tokens to generate 1054 - temperature: Sampling temperature 1055 - top_p: Top-p sampling parameter 1056 - repetition_penalty: Penalty for repeated tokens 1057 - stream: Whether to stream output 1058 - 1059 - Returns: 1060 - Generated text (in non-interactive mode) 1061 - """ 1062 - try: 1063 - with MLXRunner(model_path, verbose=verbose) as runner: 1064 - # Interactive mode 1065 - if interactive or prompt is None: 1066 - runner.interactive_chat( 1067 - max_tokens=max_tokens, 1068 - temperature=temperature, 1069 - top_p=top_p, 1070 - repetition_penalty=repetition_penalty, 1071 - use_chat_template=use_chat_template, 1072 - ) 1073 - return None 1074 - 1075 - # Single prompt mode 1076 - if verbose: 1077 - print(f"\nPrompt: {prompt}\n") 1078 - print("Response: ", end="", flush=True) 1079 - 1080 - if stream: 1081 - # Streaming generation 1082 - response_tokens = [] 1083 - try: 1084 - for token in runner.generate_streaming( 1085 - prompt=prompt, 1086 - max_tokens=max_tokens, 1087 - temperature=temperature, 1088 - top_p=top_p, 1089 - repetition_penalty=repetition_penalty, 1090 - use_chat_template=use_chat_template, 1091 - hide_reasoning=hide_reasoning, 1092 - ): 1093 - # Stream all tokens directly (already formatted by generate_streaming) 1094 - print(token, end="", flush=True) 1095 - response_tokens.append(token) 1096 - except KeyboardInterrupt: 1097 - print("\n[INFO] Generation interrupted by user.") 1098 - response = "".join(response_tokens) 1099 - else: 1100 - # Batch generation 1101 - try: 1102 - response = runner.generate_batch( 1103 - prompt=prompt, 1104 - max_tokens=max_tokens, 1105 - temperature=temperature, 1106 - top_p=top_p, 1107 - repetition_penalty=repetition_penalty, 1108 - use_chat_template=use_chat_template, 1109 - ) 1110 - except KeyboardInterrupt: 1111 - print("\n[INFO] Generation interrupted by user.") 1112 - response = "" 1113 - print(response) 1114 - 1115 - # Show memory usage if verbose 1116 - if verbose: 1117 - memory_stats = runner.get_memory_usage() 1118 - print( 1119 - f"\n\nMemory: {memory_stats['model_gb']:.1f}GB model, {memory_stats['current_gb']:.1f}GB total" 1120 - ) 1121 - 1122 - return response 1123 - 1124 - # Note: cleanup happens automatically due to context manager 1125 - 1126 - except Exception as e: 1127 - print(f"\n[ERROR] {e}") 1128 - return None
+6 -3
server/pyproject.toml
··· 11 11 "huggingface-hub>=0.34.0", 12 12 ] 13 13 14 - [build-system] 15 - requires = ["setuptools", "wheel"] 16 - build-backend = "setuptools.build_meta" 14 + # [build-system] 15 + # requires = ["setuptools", "wheel"] 16 + # build-backend = "setuptools.build_meta" 17 + 18 + # [tool.setuptools] 19 + # packages = ["server", "server.backend", "server.mem_agent"]
+1 -1
server/uv.lock
··· 495 495 [[package]] 496 496 name = "server" 497 497 version = "0.1.0" 498 - source = { editable = "." } 498 + source = { virtual = "." } 499 499 dependencies = [ 500 500 { name = "black" }, 501 501 { name = "fastapi" },