A local-first private AI assistant for everyday use. Runs on-device models with encrypted P2P sync, and supports sharing chats publicly on ATProto.
10
fork

Configure Feed

Select the types of activity you want to include in your feed.

refactor: multiple backends for different environments

madclaws 117becde e1a3fd2d

+1267 -171
+9
ATTRIBUTIONS.txt
··· 9 9 10 10 Description: 11 11 Modules regarding mlx from mlx-knife has been used as our starting point and for further references 12 + 13 + 14 + Project: mem-agent-mcp 15 + Author: Dria 16 + Source: https://github.com/firstbatchxyz/mem-agent-mcp 17 + License: Apache-2.0 license 18 + 19 + Description: 20 + Modules regarding mem-agent cli from mem-agent-mcp has been used as our starting point and for further references
+1
server/.gitignore
··· 1 1 __pycache__/ 2 2 *.egg-info/ 3 3 .venv/ 4 + backend/__pycache__
+9 -92
server/api.py
··· 1 1 from fastapi import FastAPI, HTTPException 2 + 3 + from .schemas import ChatMessage, ChatCompletionRequest, StartRequest, downloadRequest 2 4 from .config import SYSTEM_PROMPT 3 5 import logging 4 - import json 5 - import time 6 - import uuid 7 6 import sys 8 - from collections.abc import AsyncGenerator 9 - from typing import Any, Dict, List, Optional, Union 7 + from typing import Optional 10 8 11 9 from fastapi.responses import StreamingResponse 12 10 from pydantic import BaseModel, Field 13 11 14 - from .cache_utils import get_model_path 15 12 from .hf_downloader import pull_model 16 13 17 14 from server.mem_agent.utils import ( 18 - extract_python_code, 19 - extract_reply, 20 - extract_thoughts, 21 15 create_memory_if_not_exists, 22 16 format_results, 23 17 ) 24 18 from server.mem_agent.engine import execute_sandboxed_code 25 19 26 - # Global model cache and configuration 27 - 28 - if sys.platform == "darwin": 29 - from .mlx_api import generate_chat_stream, get_or_load_model 30 - 20 + from . import runtime 31 21 32 22 logger = logging.getLogger("app") 33 23 _current_model_path: Optional[str] = None 34 24 _default_max_tokens: Optional[int] = None # Use dynamic model-aware limits by default 35 - _max_tool_turns = 5 36 25 _memory_path = "" 37 26 38 - 39 - class CompletionRequest(BaseModel): 40 - model: str 41 - prompt: Union[str, List[str]] 42 - max_tokens: Optional[int] = None 43 - temperature: Optional[float] = 0.7 44 - top_p: Optional[float] = 0.9 45 - stream: Optional[bool] = False 46 - stop: Optional[Union[str, List[str]]] = None 47 - repetition_penalty: Optional[float] = 1.1 48 - 49 - 50 - class ChatMessage(BaseModel): 51 - role: str = Field(..., pattern="^(system|user|assistant)$") 52 - content: str 53 - 54 - 55 27 _messages: list[ChatMessage] = [] 56 28 57 29 58 - class ChatCompletionRequest(BaseModel): 59 - model: str 60 - messages: List[ChatMessage] 61 - chat_start: bool 62 - python_code: str 63 - max_tokens: Optional[int] = None 64 - temperature: Optional[float] = 0.7 65 - top_p: Optional[float] = 0.9 66 - stream: Optional[bool] = False 67 - stop: Optional[Union[str, List[str]]] = None 68 - repetition_penalty: Optional[float] = 1.1 69 - 70 - 71 - class CompletionResponse(BaseModel): 72 - id: str 73 - object: str = "text_completion" 74 - created: int 75 - model: str 76 - choices: List[Dict[str, Any]] 77 - usage: Dict[str, int] 78 - 79 - 80 - class ChatCompletionResponse(BaseModel): 81 - id: str 82 - object: str = "chat.completion" 83 - created: int 84 - model: str 85 - choices: List[Dict[str, Any]] 86 - # usage: Dict[str, int] 87 - 88 - 89 - class ModelInfo(BaseModel): 90 - id: str 91 - object: str = "model" 92 - owned_by: str = "mlx-knife" 93 - permission: List = [] 94 - context_length: Optional[int] = None 95 - 96 - 97 - class StartRequest(BaseModel): 98 - model: str 99 - memory_path: str 100 - 101 - 102 - class downloadRequest(BaseModel): 103 - model: str 104 - 105 - 106 30 app = FastAPI() 107 31 108 32 ··· 114 38 @app.post("/download") 115 39 async def download(request: downloadRequest): 116 40 """Download the model""" 117 - try: 118 - if pull_model(request.model): 119 - return {"message": "Model downloaded"} 120 - else: 121 - raise HTTPException(status_code=400, detail="Downloading model failed") 122 - except Exception as e: 123 - raise HTTPException(status_code=500, detail=str(e)) 124 - 41 + runtime.backend.download_model(request.model) 125 42 126 43 @app.post("/start") 127 44 async def start_model(request: StartRequest): ··· 130 47 131 48 _messages = [ChatMessage(role="system", content=SYSTEM_PROMPT)] 132 49 _memory_path = request.memory_path 133 - 134 - get_or_load_model(request.model) 50 + logger.info(f"{runtime.backend}") 51 + runtime.backend.get_or_load_model(request.model) 135 52 return {"message": "Model loaded"} 136 53 137 54 138 55 @app.post("/v1/chat/completions") 139 56 async def create_chat_completion(request: ChatCompletionRequest): 140 57 """Create a chat completion.""" 141 - global _messages, _max_tool_turns, _memory_path 58 + global _messages, _memory_path 142 59 try: 143 60 144 61 if request.stream: ··· 157 74 158 75 # Streaming response 159 76 return StreamingResponse( 160 - generate_chat_stream(request.model, request.messages, request), 77 + runtime.backend.generate_chat_stream(_messages, request), 161 78 media_type="text/plain", 162 79 headers={"Cache-Control": "no-cache"}, 163 80 )
+1
server/backend/__init__.py
··· 1 +
+1
server/backend/linux.py
··· 1 + # Module for linux backend
+1128
server/backend/mlx_runner.py
··· 1 + """ 2 + Enhanced MLX model runner with direct API integration. 3 + Provides ollama-like run experience with streaming and interactive chat. 4 + """ 5 + 6 + import sys 7 + import json 8 + import os 9 + import time 10 + from collections.abc import Iterator 11 + from pathlib import Path 12 + from typing import Dict, Optional 13 + 14 + if sys.platform == "darwin": 15 + import mlx.core as mx 16 + else: 17 + mx = None 18 + from mlx_lm import load 19 + from mlx_lm.generate import generate_step 20 + from mlx_lm.sample_utils import make_repetition_penalty, make_sampler 21 + 22 + from ..reasoning_utils import ReasoningExtractor, StreamingReasoningParser 23 + 24 + 25 + def get_model_context_length(model_path: str) -> int: 26 + """Extract max_position_embeddings from model config. 27 + 28 + Args: 29 + model_path: Path to the MLX model directory 30 + 31 + Returns: 32 + Maximum context length for the model (defaults to 4096 if not found) 33 + """ 34 + config_path = os.path.join(model_path, "config.json") 35 + 36 + try: 37 + with open(config_path) as f: 38 + config = json.load(f) 39 + 40 + # Try various common config keys for context length 41 + context_keys = [ 42 + "max_position_embeddings", 43 + "n_positions", 44 + "context_length", 45 + "max_sequence_length", 46 + "seq_len", 47 + ] 48 + 49 + for key in context_keys: 50 + if key in config: 51 + return config[key] 52 + 53 + # If no context length found, return reasonable default 54 + return 4096 55 + 56 + except (FileNotFoundError, json.JSONDecodeError, KeyError): 57 + # Return default if config can't be read 58 + return 4096 59 + 60 + 61 + class MLXRunner: 62 + """Direct MLX model runner with streaming and interactive capabilities.""" 63 + 64 + def __init__( 65 + self, model_path: str, adapter_path: Optional[str] = None, verbose: bool = False 66 + ): 67 + """Initialize the runner with a model. 68 + 69 + Args: 70 + model_path: Path to the MLX model directory 71 + adapter_path: Optional path to LoRA adapter 72 + verbose: Show detailed output 73 + """ 74 + self.model_path = Path(model_path) 75 + self.adapter_path = adapter_path 76 + self.model = None 77 + self.tokenizer = None 78 + self._memory_baseline = None 79 + self._stop_tokens = None # Will be populated from tokenizer 80 + self._message_end_tokens = None # Message-end tokens (e.g., <|end|> for MXFP4) 81 + self._chat_stop_tokens = None # Chat-specific stop tokens 82 + self._context_length = None # Will be populated from model config 83 + self._is_reasoning_model = False # Whether model uses reasoning (MXFP4) 84 + self._reasoning_start = None # Reasoning start marker 85 + self._reasoning_end = None # Reasoning end marker 86 + self._final_start = None # Final answer start marker 87 + self.verbose = verbose 88 + self._model_loaded = False 89 + self._context_entered = False # Prevent nested context usage 90 + 91 + def __enter__(self): 92 + """Context manager entry - loads the model.""" 93 + if self._context_entered: 94 + raise RuntimeError( 95 + "MLXRunner context manager cannot be entered multiple times" 96 + ) 97 + 98 + self._context_entered = True 99 + try: 100 + self.load_model() 101 + return self 102 + except Exception: 103 + # If load_model fails, ensure cleanup happens 104 + self._context_entered = False 105 + self.cleanup() 106 + raise 107 + 108 + def __exit__(self, exc_type, exc_val, exc_tb): 109 + """Context manager exit - cleans up the model.""" 110 + self._context_entered = False 111 + self.cleanup() 112 + return False # Don't suppress exceptions 113 + 114 + def load_model(self): 115 + """Load the MLX model and tokenizer.""" 116 + if self._model_loaded: 117 + if self.verbose: 118 + print("Model already loaded, skipping...") 119 + return 120 + 121 + if self.verbose: 122 + print(f"Loading model from {self.model_path}...") 123 + start_time = time.time() 124 + 125 + # Capture baseline memory before loading 126 + try: 127 + mx.clear_cache() 128 + except Exception: 129 + pass # Continue even if cache clear fails 130 + self._memory_baseline = mx.get_active_memory() / 1024**3 131 + 132 + try: 133 + # Load model and tokenizer 134 + self.model, self.tokenizer = load( 135 + str(self.model_path), adapter_path=self.adapter_path 136 + ) 137 + 138 + load_time = time.time() - start_time 139 + current_memory = mx.get_active_memory() / 1024**3 140 + model_memory = current_memory - self._memory_baseline 141 + 142 + if self.verbose: 143 + print(f"Model loaded in {load_time:.1f}s") 144 + print( 145 + f"Memory: {model_memory:.1f}GB model, {current_memory:.1f}GB total" 146 + ) 147 + 148 + # Extract stop tokens from tokenizer 149 + self._extract_stop_tokens() 150 + 151 + # Extract context length from model config 152 + self._context_length = get_model_context_length(str(self.model_path)) 153 + 154 + if self.verbose: 155 + print(f"Model context length: {self._context_length} tokens") 156 + 157 + self._model_loaded = True 158 + 159 + except Exception as e: 160 + # Ensure partial state is cleaned up on failure 161 + self.model = None 162 + self.tokenizer = None 163 + self._stop_tokens = None 164 + self._model_loaded = False 165 + # Clear any memory that might have been allocated 166 + mx.clear_cache() 167 + raise RuntimeError( 168 + f"Failed to load model from {self.model_path}: {e}" 169 + ) from e 170 + 171 + def _extract_stop_tokens(self): 172 + """Extract stop tokens from the tokenizer dynamically. 173 + 174 + This method identifies ALL tokens that should stop generation: 175 + 1. Official EOS token from tokenizer config 176 + 2. Message-end tokens from training (e.g., <|end|> for MXFP4) 177 + 3. Common stop tokens across models 178 + """ 179 + self._stop_tokens = set() 180 + self._message_end_tokens = ( 181 + set() 182 + ) # Tokens that end messages but not conversations 183 + 184 + # Primary source: eos_token 185 + eos_token = getattr(self.tokenizer, "eos_token", None) 186 + if eos_token: 187 + self._stop_tokens.add(eos_token) 188 + 189 + # Also check pad_token if it's different from eos_token 190 + pad_token = getattr(self.tokenizer, "pad_token", None) 191 + if pad_token and pad_token != eos_token: 192 + self._stop_tokens.add(pad_token) 193 + 194 + # Check additional_special_tokens 195 + if hasattr(self.tokenizer, "additional_special_tokens"): 196 + for token in self.tokenizer.additional_special_tokens: 197 + if token and isinstance(token, str): 198 + # Only add tokens that look like stop/end tokens 199 + if any( 200 + keyword in token.lower() for keyword in ["end", "stop", "eot"] 201 + ): 202 + self._stop_tokens.add(token) 203 + 204 + # MLX-LM 0.27.0+: Extract tokens from added_tokens_decoder (comprehensive source) 205 + if hasattr(self.tokenizer, "added_tokens_decoder"): 206 + for _token_id, token_info in self.tokenizer.added_tokens_decoder.items(): 207 + if isinstance(token_info, dict) and "content" in token_info: 208 + token_content = token_info["content"] 209 + if token_content and isinstance(token_content, str): 210 + token_lower = token_content.lower() 211 + 212 + # NOTE: <|end|> is NOT a stop token for MXFP4 models! 213 + # It's a separator between reasoning and final answer 214 + if token_content == "<|end|>": 215 + self._message_end_tokens.add(token_content) 216 + # Do NOT add as stop token - let model continue to final answer 217 + 218 + # Look for tokens that could be end/stop tokens 219 + # Expanded patterns for MLX-LM 0.27.0 token varieties 220 + # EXCLUDE <|end|> for MXFP4 models as it's a reasoning separator 221 + end_patterns = [ 222 + "stop", 223 + "eot", 224 + "return", 225 + "finish", 226 + "done", 227 + "im_end", 228 + ] 229 + if any(pattern in token_lower for pattern in end_patterns): 230 + # Decide if it's a message-end or conversation-end token 231 + if "im_end" in token_lower: 232 + self._message_end_tokens.add(token_content) 233 + self._stop_tokens.add(token_content) 234 + # Special handling for 'end' pattern - more selective 235 + elif "end" in token_lower and token_content != "<|end|>": 236 + # Only add non-<|end|> tokens with 'end' in them 237 + self._stop_tokens.add(token_content) 238 + 239 + # Special case: control tokens in |..| format 240 + elif token_content.startswith("<|") and token_content.endswith( 241 + "|>" 242 + ): 243 + # Be inclusive with control tokens that might stop generation 244 + if any( 245 + pattern in token_lower 246 + for pattern in ["end", "return", "stop", "finish"] 247 + ): 248 + self._stop_tokens.add(token_content) 249 + 250 + # Model-specific handling based on known patterns 251 + # Use reasoning_utils for reasoning model detection and patterns 252 + from .reasoning_utils import ReasoningExtractor 253 + 254 + if hasattr(self.tokenizer, "name_or_path"): 255 + name_or_path = str(getattr(self.tokenizer, "name_or_path", "")).lower() 256 + model_type = ReasoningExtractor.detect_model_type(name_or_path) 257 + 258 + if model_type: 259 + # This is a reasoning model 260 + self._is_reasoning_model = True 261 + 262 + # Get patterns from reasoning_utils 263 + if model_type in ReasoningExtractor.PATTERNS: 264 + markers = ReasoningExtractor.PATTERNS[model_type]["markers"] 265 + self._reasoning_start = markers.get("reasoning_start") 266 + self._reasoning_end = markers.get("reasoning_end") 267 + self._final_start = markers.get("final_marker") 268 + 269 + # For reasoning models, remove reasoning_end from stop tokens 270 + if self._reasoning_end: 271 + self._stop_tokens.discard(self._reasoning_end) 272 + 273 + # Add proper stop token for this model type 274 + if model_type == "gpt-oss": 275 + if "<|return|>" not in self._stop_tokens: 276 + self._stop_tokens.add("<|return|>") 277 + else: 278 + self._is_reasoning_model = False 279 + else: 280 + self._is_reasoning_model = False 281 + 282 + # Add common stop tokens that might not be in special tokens 283 + common_stop_tokens = {"</s>", "<|endoftext|>", "<|im_end|>", "<|eot_id|>"} 284 + 285 + # Add chat-specific stop tokens to prevent model self-conversations 286 + # Based on our _format_conversation() format: "Human:" and "Assistant:" 287 + # Also include "You:" as models might use UI-visible format 288 + # Include single-letter variations (H:, A:, Y:) that some models use 289 + chat_stop_tokens = { 290 + "\nHuman:", 291 + "\nAssistant:", 292 + "\nYou:", 293 + "\n\nHuman:", 294 + "\n\nAssistant:", 295 + "\n\nYou:", 296 + "\nH:", 297 + "\nA:", 298 + "\nY:", # Single-letter variations 299 + "\n\nH:", 300 + "\n\nA:", 301 + "\n\nY:", 302 + } 303 + 304 + # Add common stop tokens only if they decode to themselves (i.e., they're single tokens) 305 + for token in common_stop_tokens: 306 + try: 307 + # Try to encode and decode to verify it's a real single token 308 + ids = self.tokenizer.encode(token, add_special_tokens=False) 309 + if ids and len(ids) == 1: # Single token ID means it's a special token 310 + decoded = self.tokenizer.decode(ids) 311 + if decoded == token: 312 + self._stop_tokens.add(token) 313 + except: 314 + pass 315 + 316 + # Store chat stop tokens separately - only used in interactive chat mode 317 + # This prevents stopping mid-story when user asks for dialogues 318 + self._chat_stop_tokens = list(chat_stop_tokens) 319 + 320 + # Remove any None values 321 + self._stop_tokens.discard(None) 322 + self._message_end_tokens.discard(None) 323 + 324 + # Convert to list for easier use 325 + self._stop_tokens = list(self._stop_tokens) 326 + self._message_end_tokens = list(self._message_end_tokens) 327 + 328 + if self.verbose: 329 + if self._stop_tokens: 330 + print(f"Stop tokens: {self._stop_tokens}") 331 + if self._message_end_tokens: 332 + print(f"Message end tokens: {self._message_end_tokens}") 333 + 334 + def cleanup(self): 335 + """Clean up model resources and clear GPU memory. 336 + 337 + This method is safe to call multiple times and handles partial state cleanup. 338 + """ 339 + if self.verbose and self._model_loaded: 340 + memory_before = mx.get_active_memory() / 1024**3 341 + print(f"Cleaning up model (memory before: {memory_before:.1f}GB)...") 342 + 343 + # Always clean up, even if model wasn't fully loaded 344 + self.model = None 345 + self.tokenizer = None 346 + self._stop_tokens = None 347 + self._message_end_tokens = None 348 + self._chat_stop_tokens = None 349 + self._context_length = None 350 + self._is_reasoning_model = False 351 + self._reasoning_start = None 352 + self._reasoning_end = None 353 + self._final_start = None 354 + self._model_loaded = False 355 + 356 + # Force garbage collection and clear MLX cache 357 + import gc 358 + 359 + gc.collect() 360 + try: 361 + mx.clear_cache() 362 + except Exception: 363 + pass # Continue cleanup even if cache clear fails 364 + 365 + if self.verbose: 366 + memory_after = mx.get_active_memory() / 1024**3 367 + if "memory_before" in locals(): 368 + memory_freed = memory_before - memory_after 369 + print( 370 + f"Cleanup complete (memory after: {memory_after:.1f}GB, freed: {memory_freed:.1f}GB)" 371 + ) 372 + else: 373 + print(f"Cleanup complete (memory after: {memory_after:.1f}GB)") 374 + 375 + def get_effective_max_tokens( 376 + self, requested_tokens: Optional[int], interactive: bool = False 377 + ) -> int: 378 + """Get effective max tokens based on model context and usage mode. 379 + 380 + Args: 381 + requested_tokens: The requested max tokens (None if user didn't specify --max-tokens) 382 + interactive: True if this is interactive mode (gets full context length) 383 + 384 + Returns: 385 + Effective max tokens to use 386 + """ 387 + if not self._context_length: 388 + # Fallback when context length is unknown 389 + fallback = 4096 if interactive else 2048 390 + if self.verbose: 391 + if requested_tokens is None: 392 + print( 393 + f"[WARNING] Model context length unknown, using fallback: {fallback} tokens" 394 + ) 395 + else: 396 + print( 397 + f"[WARNING] Model context length unknown, using user specified: {requested_tokens} tokens" 398 + ) 399 + return requested_tokens if requested_tokens is not None else fallback 400 + 401 + if interactive: 402 + if requested_tokens is None: 403 + # User didn't specify --max-tokens: use full model context 404 + return self._context_length 405 + else: 406 + # User specified --max-tokens explicitly: respect their choice but cap at context 407 + return min(requested_tokens, self._context_length) 408 + else: 409 + # Server/batch mode uses half context length for DoS protection 410 + server_limit = self._context_length // 2 411 + return min(requested_tokens or server_limit, server_limit) 412 + 413 + def generate_streaming( 414 + self, 415 + prompt: str, 416 + max_tokens: int = 500, 417 + temperature: float = 0.7, 418 + top_p: float = 0.9, 419 + repetition_penalty: float = 1.1, 420 + repetition_context_size: int = 20, 421 + use_chat_template: bool = True, 422 + use_chat_stop_tokens: bool = False, 423 + interactive: bool = False, 424 + hide_reasoning: bool = False, 425 + ) -> Iterator[str]: 426 + """Generate text with streaming output. 427 + 428 + Args: 429 + prompt: Input prompt 430 + max_tokens: Maximum tokens to generate 431 + temperature: Sampling temperature 432 + top_p: Top-p sampling parameter 433 + repetition_penalty: Penalty for repeated tokens 434 + repetition_context_size: Context size for repetition penalty 435 + use_chat_template: Apply tokenizer's chat template if available 436 + use_chat_stop_tokens: Include chat turn markers as stop tokens (for interactive mode) 437 + interactive: True if this is interactive mode (affects token limits) 438 + 439 + Yields: 440 + Generated tokens as they are produced 441 + """ 442 + if not self.model or not self.tokenizer: 443 + raise RuntimeError("Model not loaded. Call load_model() first.") 444 + 445 + # Initialize reasoning parser if this is a reasoning model 446 + reasoning_parser = None 447 + if self._is_reasoning_model: 448 + model_type = ReasoningExtractor.detect_model_type( 449 + getattr(self.tokenizer, "name_or_path", "") or "" 450 + ) 451 + reasoning_parser = StreamingReasoningParser( 452 + model_type, hide_reasoning=hide_reasoning 453 + ) 454 + 455 + # Apply context-aware token limits 456 + effective_max_tokens = self.get_effective_max_tokens(max_tokens, interactive) 457 + 458 + # Apply chat template if available and requested 459 + if ( 460 + use_chat_template 461 + and hasattr(self.tokenizer, "chat_template") 462 + and self.tokenizer.chat_template 463 + ): 464 + messages = [{"role": "user", "content": prompt}] 465 + formatted_prompt = self.tokenizer.apply_chat_template( 466 + messages, tokenize=False, add_generation_prompt=True 467 + ) 468 + else: 469 + formatted_prompt = prompt 470 + 471 + # Tokenize the prompt 472 + prompt_tokens = self.tokenizer.encode(formatted_prompt) 473 + prompt_array = mx.array(prompt_tokens) 474 + 475 + # Track generation metrics 476 + start_time = time.time() 477 + tokens_generated = 0 478 + 479 + # Create sampler with our parameters 480 + sampler = make_sampler(temp=temperature, top_p=top_p) 481 + 482 + # Create repetition penalty processor if needed 483 + logits_processors = [] 484 + if repetition_penalty > 1.0: 485 + logits_processors.append( 486 + make_repetition_penalty(repetition_penalty, repetition_context_size) 487 + ) 488 + 489 + # Generate tokens one by one for streaming 490 + generator = generate_step( 491 + prompt=prompt_array, 492 + model=self.model, 493 + max_tokens=effective_max_tokens, 494 + sampler=sampler, 495 + logits_processors=logits_processors if logits_processors else None, 496 + ) 497 + 498 + # Collect tokens and yield text 499 + generated_tokens = [] 500 + previous_decoded = "" 501 + accumulated_response = "" # Track full response for stop token detection 502 + 503 + # Keep a sliding window of recent tokens for context 504 + context_window = 10 # Decode last N tokens for proper spacing 505 + 506 + for token, _ in generator: 507 + # Token might be an array or an int 508 + token_id = token.item() if hasattr(token, "item") else token 509 + generated_tokens.append(token_id) 510 + 511 + # Use a sliding window approach for efficiency 512 + start_idx = max(0, len(generated_tokens) - context_window) 513 + window_tokens = generated_tokens[start_idx:] 514 + 515 + # Decode the window 516 + window_text = self.tokenizer.decode(window_tokens) 517 + 518 + # Figure out what's new 519 + if start_idx == 0: 520 + # We're still within the context window 521 + if window_text.startswith(previous_decoded): 522 + new_text = window_text[len(previous_decoded) :] 523 + else: 524 + new_text = self.tokenizer.decode([token_id]) 525 + previous_decoded = window_text 526 + else: 527 + # We're beyond the context window, just decode the last token with context 528 + # This is approximate but should preserve spaces 529 + new_text = self.tokenizer.decode(window_tokens) 530 + if len(window_tokens) > 1: 531 + prefix = self.tokenizer.decode(window_tokens[:-1]) 532 + if new_text.startswith(prefix): 533 + new_text = new_text[len(prefix) :] 534 + else: 535 + new_text = self.tokenizer.decode([token_id]) 536 + 537 + if new_text: 538 + # Update accumulated response for stop token checking 539 + accumulated_response += new_text 540 + 541 + # Filter out stop tokens with priority: native first, then chat fallback 542 + # Check native stop tokens FIRST in accumulated response (highest priority) 543 + native_stop_tokens = self._stop_tokens if self._stop_tokens else [] 544 + for stop_token in native_stop_tokens: 545 + if stop_token in accumulated_response: 546 + # Find the stop token position and yield everything before it 547 + stop_pos = accumulated_response.find(stop_token) 548 + # Calculate what text came before the stop token 549 + text_before_stop = accumulated_response[:stop_pos] 550 + # Calculate how much of that is new (not previously yielded) 551 + previously_yielded_length = len(accumulated_response) - len( 552 + new_text 553 + ) 554 + if len(text_before_stop) > previously_yielded_length: 555 + # Yield only the new part before stop token 556 + new_part_before_stop = text_before_stop[ 557 + previously_yielded_length: 558 + ] 559 + if new_part_before_stop: 560 + if reasoning_parser: 561 + # Process through reasoning parser for formatting 562 + for ( 563 + formatted_token 564 + ) in reasoning_parser.process_token( 565 + new_part_before_stop 566 + ): 567 + yield formatted_token 568 + else: 569 + yield new_part_before_stop 570 + return # Stop generation without yielding stop token 571 + 572 + # Only check chat stop tokens if no native stop token found (fallback) 573 + if use_chat_stop_tokens and self._chat_stop_tokens: 574 + for stop_token in self._chat_stop_tokens: 575 + if stop_token in accumulated_response: 576 + # Find the stop token position and yield everything before it 577 + stop_pos = accumulated_response.find(stop_token) 578 + # Calculate what text came before the stop token 579 + text_before_stop = accumulated_response[:stop_pos] 580 + # Calculate how much of that is new (not previously yielded) 581 + previously_yielded_length = len(accumulated_response) - len( 582 + new_text 583 + ) 584 + if len(text_before_stop) > previously_yielded_length: 585 + # Yield only the new part before stop token 586 + new_part_before_stop = text_before_stop[ 587 + previously_yielded_length: 588 + ] 589 + if new_part_before_stop: 590 + if reasoning_parser: 591 + # Process through reasoning parser for formatting 592 + for ( 593 + formatted_token 594 + ) in reasoning_parser.process_token( 595 + new_part_before_stop 596 + ): 597 + yield formatted_token 598 + else: 599 + yield new_part_before_stop 600 + return # Stop generation without yielding stop token 601 + 602 + # No stop token found, process the new text 603 + if reasoning_parser: 604 + # Process through reasoning parser for formatting 605 + for formatted_token in reasoning_parser.process_token(new_text): 606 + yield formatted_token 607 + else: 608 + # Normal streaming for non-reasoning models 609 + yield new_text 610 + tokens_generated += 1 611 + 612 + # Check for EOS token - don't yield it 613 + if token_id == self.tokenizer.eos_token_id: 614 + break 615 + 616 + # Finalize reasoning parser if used 617 + if reasoning_parser: 618 + yield from reasoning_parser.finalize() 619 + 620 + # Print generation statistics if verbose 621 + if self.verbose: 622 + generation_time = time.time() - start_time 623 + tokens_per_second = ( 624 + tokens_generated / generation_time if generation_time > 0 else 0 625 + ) 626 + print( 627 + f"\n\nGenerated {tokens_generated} tokens in {generation_time:.1f}s ({tokens_per_second:.1f} tokens/s)" 628 + ) 629 + 630 + def generate_batch( 631 + self, 632 + prompt: str, 633 + max_tokens: int = 500, 634 + temperature: float = 0.7, 635 + top_p: float = 0.9, 636 + repetition_penalty: float = 1.1, 637 + repetition_context_size: int = 20, 638 + use_chat_template: bool = True, 639 + interactive: bool = False, 640 + ) -> str: 641 + """Generate text in batch mode (non-streaming). 642 + 643 + Args: 644 + prompt: Input prompt 645 + max_tokens: Maximum tokens to generate 646 + temperature: Sampling temperature 647 + top_p: Top-p sampling parameter 648 + repetition_penalty: Penalty for repeated tokens 649 + repetition_context_size: Context size for repetition penalty 650 + use_chat_template: Apply tokenizer's chat template if available 651 + interactive: True if this is interactive mode (affects token limits) 652 + 653 + Returns: 654 + Generated text 655 + """ 656 + if not self.model or not self.tokenizer: 657 + raise RuntimeError("Model not loaded. Call load_model() first.") 658 + 659 + # Apply context-aware token limits 660 + effective_max_tokens = self.get_effective_max_tokens(max_tokens, interactive) 661 + 662 + # Apply chat template if available and requested 663 + if ( 664 + use_chat_template 665 + and hasattr(self.tokenizer, "chat_template") 666 + and self.tokenizer.chat_template 667 + ): 668 + messages = [{"role": "user", "content": prompt}] 669 + formatted_prompt = self.tokenizer.apply_chat_template( 670 + messages, tokenize=False, add_generation_prompt=True 671 + ) 672 + else: 673 + formatted_prompt = prompt 674 + 675 + start_time = time.time() 676 + 677 + # Tokenize the prompt 678 + prompt_tokens = self.tokenizer.encode(formatted_prompt) 679 + prompt_array = mx.array(prompt_tokens) 680 + 681 + # Create sampler with our parameters 682 + sampler = make_sampler(temp=temperature, top_p=top_p) 683 + 684 + # Create repetition penalty processor if needed 685 + logits_processors = [] 686 + if repetition_penalty > 1.0: 687 + logits_processors.append( 688 + make_repetition_penalty(repetition_penalty, repetition_context_size) 689 + ) 690 + 691 + # Generate all tokens at once 692 + generated_tokens = [] 693 + all_tokens = list(prompt_tokens) # Keep prompt for proper decoding 694 + 695 + generator = generate_step( 696 + prompt=prompt_array, 697 + model=self.model, 698 + max_tokens=effective_max_tokens, 699 + sampler=sampler, 700 + logits_processors=logits_processors if logits_processors else None, 701 + ) 702 + 703 + for token, _ in generator: 704 + # Token might be an array or an int 705 + token_id = token.item() if hasattr(token, "item") else token 706 + generated_tokens.append(token_id) 707 + all_tokens.append(token_id) 708 + 709 + # Check for EOS token - don't yield it 710 + if token_id == self.tokenizer.eos_token_id: 711 + break 712 + 713 + # Decode all tokens together for proper spacing 714 + full_response = self.tokenizer.decode(all_tokens) 715 + 716 + # Remove the prompt part 717 + if full_response.startswith(formatted_prompt): 718 + response = full_response[len(formatted_prompt) :] 719 + else: 720 + # Fallback: just decode generated tokens 721 + response = self.tokenizer.decode(generated_tokens) 722 + 723 + # Apply end-token filtering (same logic as streaming mode for Issue #20) 724 + response = self._filter_end_tokens_from_response( 725 + response, use_chat_stop_tokens=False 726 + ) 727 + 728 + # Format reasoning models output 729 + response = self._format_reasoning_response(response) 730 + 731 + generation_time = time.time() - start_time 732 + 733 + # Count tokens for statistics 734 + if self.verbose: 735 + tokens_generated = len(generated_tokens) 736 + tokens_per_second = ( 737 + tokens_generated / generation_time if generation_time > 0 else 0 738 + ) 739 + print( 740 + f"\nGenerated {tokens_generated} tokens in {generation_time:.1f}s ({tokens_per_second:.1f} tokens/s)" 741 + ) 742 + 743 + return response 744 + 745 + def interactive_chat( 746 + self, 747 + system_prompt: Optional[str] = None, 748 + max_tokens: int = 500, 749 + temperature: float = 0.7, 750 + top_p: float = 0.9, 751 + repetition_penalty: float = 1.1, 752 + use_chat_template: bool = True, 753 + ): 754 + """Run an interactive chat session. 755 + 756 + Args: 757 + system_prompt: Optional system prompt to prepend 758 + max_tokens: Maximum tokens per response 759 + temperature: Sampling temperature 760 + top_p: Top-p sampling parameter 761 + repetition_penalty: Penalty for repeated tokens 762 + use_chat_template: Use tokenizer's chat template if available 763 + """ 764 + print("Starting interactive chat. Type 'exit' or 'quit' to end.\n") 765 + 766 + conversation_history = [] 767 + if system_prompt: 768 + conversation_history.append({"role": "system", "content": system_prompt}) 769 + 770 + while True: 771 + try: 772 + # Get user input 773 + user_input = input("You: ").strip() 774 + 775 + if user_input.lower() in ["exit", "quit", "q"]: 776 + print("\nGoodbye!") 777 + break 778 + 779 + if not user_input: 780 + continue 781 + 782 + # Add user message to history 783 + conversation_history.append({"role": "user", "content": user_input}) 784 + 785 + # Format conversation for the model using chat template if available 786 + prompt = self._format_conversation( 787 + conversation_history, use_chat_template=use_chat_template 788 + ) 789 + 790 + # Generate response with streaming 791 + print("\nAssistant: ", end="", flush=True) 792 + 793 + response_tokens = [] 794 + for token in self.generate_streaming( 795 + prompt=prompt, 796 + max_tokens=max_tokens, 797 + temperature=temperature, 798 + top_p=top_p, 799 + repetition_penalty=repetition_penalty, 800 + use_chat_template=False, # Already applied in _format_conversation 801 + use_chat_stop_tokens=True, # Enable chat stop tokens in interactive mode 802 + interactive=True, # Enable full context length for interactive mode 803 + ): 804 + # Stream all tokens directly (already formatted by generate_streaming) 805 + print(token, end="", flush=True) 806 + response_tokens.append(token) 807 + 808 + # Add assistant response to history 809 + assistant_response = "".join(response_tokens).strip() 810 + conversation_history.append( 811 + {"role": "assistant", "content": assistant_response} 812 + ) 813 + 814 + print() # New line after response 815 + 816 + except KeyboardInterrupt: 817 + print("\n\nChat interrupted. Goodbye!") 818 + break 819 + except Exception as e: 820 + print(f"\n[ERROR] {e}") 821 + continue 822 + 823 + def _format_conversation( 824 + self, messages: list, use_chat_template: bool = True 825 + ) -> str: 826 + """Format conversation history into a prompt. 827 + 828 + Uses the tokenizer's chat template if available, otherwise falls back 829 + to the legacy Human:/Assistant: format for compatibility. 830 + 831 + Args: 832 + messages: List of message dictionaries with 'role' and 'content' 833 + use_chat_template: Whether to use chat template if available 834 + 835 + Returns: 836 + Formatted conversation string 837 + """ 838 + # Try to use native chat template if available 839 + if ( 840 + use_chat_template 841 + and hasattr(self.tokenizer, "chat_template") 842 + and self.tokenizer.chat_template 843 + ): 844 + try: 845 + # Apply the tokenizer's chat template 846 + formatted_prompt = self.tokenizer.apply_chat_template( 847 + messages, tokenize=False, add_generation_prompt=True 848 + ) 849 + return formatted_prompt 850 + except Exception as e: 851 + # If chat template fails, fall back to legacy format 852 + if self.verbose: 853 + print(f"[WARNING] Chat template failed, using legacy format: {e}") 854 + 855 + # Legacy format fallback for compatibility 856 + return self._legacy_format_conversation(messages) 857 + 858 + def _legacy_format_conversation(self, messages: list) -> str: 859 + """Legacy conversation formatting for backward compatibility. 860 + 861 + This format was used in earlier versions and remains as a fallback 862 + for models without chat templates. 863 + """ 864 + formatted = [] 865 + 866 + for message in messages: 867 + role = message["role"] 868 + content = message["content"] 869 + 870 + if role == "system": 871 + formatted.append(f"System: {content}") 872 + elif role == "user": 873 + formatted.append(f"Human: {content}") 874 + elif role == "assistant": 875 + formatted.append(f"Assistant: {content}") 876 + 877 + # Add prompt for next assistant response 878 + formatted.append("Assistant:") 879 + 880 + return "\n\n".join(formatted) 881 + 882 + def get_memory_usage(self) -> Dict[str, float]: 883 + """Get current memory usage statistics. 884 + 885 + Returns: 886 + Dictionary with memory statistics in GB 887 + """ 888 + try: 889 + current_memory = mx.get_active_memory() / 1024**3 890 + peak_memory = mx.get_peak_memory() / 1024**3 891 + except Exception: 892 + # Return zeros if memory stats unavailable 893 + current_memory = 0.0 894 + peak_memory = 0.0 895 + 896 + return { 897 + "current_gb": current_memory, 898 + "peak_gb": peak_memory, 899 + "model_gb": ( 900 + current_memory - self._memory_baseline if self._memory_baseline else 0 901 + ), 902 + } 903 + 904 + def _format_reasoning_response(self, response: str) -> str: 905 + """Format response from reasoning models for better readability. 906 + 907 + For MXFP4 models that generate reasoning followed by final answer, 908 + format it nicely for display. 909 + """ 910 + if not self._is_reasoning_model: 911 + return response 912 + 913 + # Check if response contains reasoning markers 914 + if self._reasoning_start in response and self._final_start in response: 915 + # Extract reasoning and final parts 916 + try: 917 + # Split on the reasoning start 918 + before_reasoning, after_start = response.split(self._reasoning_start, 1) 919 + 920 + # Find the reasoning content (until <|end|>) 921 + if self._reasoning_end in after_start: 922 + reasoning_content, after_reasoning = after_start.split( 923 + self._reasoning_end, 1 924 + ) 925 + 926 + # Find the final answer 927 + if self._final_start in after_reasoning: 928 + # Extract everything after final marker 929 + final_parts = after_reasoning.split(self._final_start, 1) 930 + if len(final_parts) > 1: 931 + # Remove the <|channel|>final<|message|> marker 932 + final_answer = final_parts[1].replace( 933 + "<|channel|>final<|message|>", "", 1 934 + ) 935 + 936 + # Format with clear markers for parsing but minimal visual impact 937 + formatted = [] 938 + formatted.append("\n**[Reasoning]**\n") 939 + formatted.append(reasoning_content.strip()) 940 + formatted.append("\n\n---\n\n**[Answer]**\n") 941 + formatted.append(final_answer.strip()) 942 + 943 + return "\n".join(formatted) 944 + except Exception: 945 + # If parsing fails, return original 946 + pass 947 + 948 + # Fallback: just clean up the control tokens 949 + cleaned = response 950 + for marker in [ 951 + "<|channel|>analysis<|message|>", 952 + "<|end|>", 953 + "<|start|>assistant", 954 + "<|channel|>final<|message|>", 955 + "<|return|>", 956 + ]: 957 + cleaned = cleaned.replace(marker, "") 958 + 959 + return cleaned.strip() 960 + 961 + def _filter_end_tokens_from_response( 962 + self, response: str, use_chat_stop_tokens: bool = False 963 + ) -> str: 964 + """Filter end tokens from a complete response (batch mode). 965 + 966 + This method applies the same filtering logic as the streaming mode 967 + to ensure consistent behavior between streaming and non-streaming. 968 + 969 + Args: 970 + response: The complete generated response 971 + use_chat_stop_tokens: Whether to apply chat stop tokens 972 + 973 + Returns: 974 + Response with end tokens filtered out 975 + """ 976 + # Apply native stop token filtering FIRST (highest priority) 977 + native_stop_tokens = self._stop_tokens if self._stop_tokens else [] 978 + for stop_token in native_stop_tokens: 979 + if stop_token in response: 980 + # Find the stop token position and return everything before it 981 + stop_pos = response.find(stop_token) 982 + filtered_response = response[:stop_pos].rstrip() 983 + if self.verbose: 984 + print( 985 + f"[DEBUG] Filtered stop token '{stop_token}' at position {stop_pos}" 986 + ) 987 + return filtered_response 988 + 989 + # Only check chat stop tokens if no native stop token found (fallback) 990 + if use_chat_stop_tokens and self._chat_stop_tokens: 991 + for stop_token in self._chat_stop_tokens: 992 + if stop_token in response: 993 + # Find the stop token position and return everything before it 994 + stop_pos = response.find(stop_token) 995 + return response[:stop_pos] 996 + 997 + # No stop tokens found, return original response 998 + return response 999 + 1000 + 1001 + def get_gpu_status() -> Dict[str, float]: 1002 + """Independent GPU status check - usable from anywhere. 1003 + 1004 + Returns: 1005 + Dictionary with GPU memory statistics in GB 1006 + """ 1007 + return { 1008 + "active_memory_gb": mx.get_active_memory() / 1024**3, 1009 + "peak_memory_gb": mx.get_peak_memory() / 1024**3, 1010 + } 1011 + 1012 + 1013 + def check_memory_available(required_gb: float) -> bool: 1014 + """Pre-flight check before model loading. 1015 + 1016 + Args: 1017 + required_gb: Required memory in GB 1018 + 1019 + Returns: 1020 + True if memory is likely available (conservative estimate) 1021 + """ 1022 + current_memory = mx.get_active_memory() / 1024**3 1023 + 1024 + # Conservative estimate: assume system has at least 8GB unified memory 1025 + # and we should leave some headroom (2GB) for system processes 1026 + estimated_total = 8.0 # This could be improved by detecting actual system memory 1027 + available = estimated_total - current_memory - 2.0 # 2GB headroom 1028 + 1029 + return available >= required_gb 1030 + 1031 + 1032 + def run_model_enhanced( 1033 + model_path: str, 1034 + prompt: Optional[str] = None, 1035 + interactive: bool = False, 1036 + max_tokens: int = 500, 1037 + temperature: float = 0.7, 1038 + top_p: float = 0.9, 1039 + repetition_penalty: float = 1.1, 1040 + stream: bool = True, 1041 + use_chat_template: bool = True, 1042 + hide_reasoning: bool = False, 1043 + verbose: bool = False, 1044 + ) -> Optional[str]: 1045 + """Enhanced run function with direct MLX integration. 1046 + 1047 + Uses context manager pattern for automatic resource cleanup. 1048 + 1049 + Args: 1050 + model_path: Path to the MLX model 1051 + prompt: Input prompt (if None, enters interactive mode) 1052 + interactive: Force interactive mode 1053 + max_tokens: Maximum tokens to generate 1054 + temperature: Sampling temperature 1055 + top_p: Top-p sampling parameter 1056 + repetition_penalty: Penalty for repeated tokens 1057 + stream: Whether to stream output 1058 + 1059 + Returns: 1060 + Generated text (in non-interactive mode) 1061 + """ 1062 + try: 1063 + with MLXRunner(model_path, verbose=verbose) as runner: 1064 + # Interactive mode 1065 + if interactive or prompt is None: 1066 + runner.interactive_chat( 1067 + max_tokens=max_tokens, 1068 + temperature=temperature, 1069 + top_p=top_p, 1070 + repetition_penalty=repetition_penalty, 1071 + use_chat_template=use_chat_template, 1072 + ) 1073 + return None 1074 + 1075 + # Single prompt mode 1076 + if verbose: 1077 + print(f"\nPrompt: {prompt}\n") 1078 + print("Response: ", end="", flush=True) 1079 + 1080 + if stream: 1081 + # Streaming generation 1082 + response_tokens = [] 1083 + try: 1084 + for token in runner.generate_streaming( 1085 + prompt=prompt, 1086 + max_tokens=max_tokens, 1087 + temperature=temperature, 1088 + top_p=top_p, 1089 + repetition_penalty=repetition_penalty, 1090 + use_chat_template=use_chat_template, 1091 + hide_reasoning=hide_reasoning, 1092 + ): 1093 + # Stream all tokens directly (already formatted by generate_streaming) 1094 + print(token, end="", flush=True) 1095 + response_tokens.append(token) 1096 + except KeyboardInterrupt: 1097 + print("\n[INFO] Generation interrupted by user.") 1098 + response = "".join(response_tokens) 1099 + else: 1100 + # Batch generation 1101 + try: 1102 + response = runner.generate_batch( 1103 + prompt=prompt, 1104 + max_tokens=max_tokens, 1105 + temperature=temperature, 1106 + top_p=top_p, 1107 + repetition_penalty=repetition_penalty, 1108 + use_chat_template=use_chat_template, 1109 + ) 1110 + except KeyboardInterrupt: 1111 + print("\n[INFO] Generation interrupted by user.") 1112 + response = "" 1113 + print(response) 1114 + 1115 + # Show memory usage if verbose 1116 + if verbose: 1117 + memory_stats = runner.get_memory_usage() 1118 + print( 1119 + f"\n\nMemory: {memory_stats['model_gb']:.1f}GB model, {memory_stats['current_gb']:.1f}GB total" 1120 + ) 1121 + 1122 + return response 1123 + 1124 + # Note: cleanup happens automatically due to context manager 1125 + 1126 + except Exception as e: 1127 + print(f"\n[ERROR] {e}") 1128 + return None
+19 -2
server/main.py
··· 1 1 import uvicorn 2 + 3 + # from backend import linux 2 4 from .api import app 3 5 from .config import PORT 4 6 import logging 5 7 import sys 6 8 from fastapi import Request 9 + from . import runtime 7 10 8 - # --- logging setup --- 9 11 logging.basicConfig( 10 12 level=logging.INFO, 11 13 format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ··· 14 16 logger = logging.getLogger("app") 15 17 16 18 17 - # --- middleware for request logging --- 18 19 @app.middleware("http") 19 20 async def log_requests(request: Request, call_next): 20 21 try: ··· 35 36 logger.info(f"<-- {request.method} {request.url.path} {response.status_code}") 36 37 return response 37 38 39 + def get_backend(): 40 + """ 41 + Dynamically choose which backend should be used depending on the OS 42 + """ 43 + if sys.platform == "darwin": 44 + from .backend import mlx 45 + logger.info("Using MLX backend (MacOs)") 46 + return mlx 47 + elif sys.platform.startswith("linux"): 48 + from .backend import linux 49 + logger.info(f"Using linux backend {sys.platform}") 50 + return linux 51 + else: 52 + raise RuntimeError(f"Unsupported OS: {sys.platform}") 53 + 54 + runtime.backend = get_backend() 38 55 39 56 def run(): 40 57 uvicorn.run(app, host="127.0.0.1", port=PORT)
+32 -60
server/mlx_api.py server/backend/mlx.py
··· 1 - from .config import SYSTEM_PROMPT 1 + from .mlx_runner import MLXRunner 2 + from ..cache_utils import get_model_path 3 + from fastapi import HTTPException 4 + from ..schemas import ChatMessage, ChatCompletionRequest, downloadRequest 5 + from ..hf_downloader import pull_model 6 + 2 7 import logging 3 8 import json 4 9 import time 5 10 import uuid 6 11 from collections.abc import AsyncGenerator 7 - from typing import Any, Dict, List, Optional, Union 8 - 9 - from pydantic import BaseModel, Field 10 - from .mlx_runner import MLXRunner 11 - from .cache_utils import get_model_path 12 - from fastapi import HTTPException 13 12 14 13 logger = logging.getLogger("app") 15 - _model_cache: Dict[str, MLXRunner] = {} 16 - _current_model_path: Optional[str] = None 17 - _default_max_tokens: Optional[int] = None # Use dynamic model-aware limits by default 18 - 19 - _runner: MLXRunner = {} 20 - 21 - 22 - class ChatMessage(BaseModel): 23 - role: str = Field(..., pattern="^(system|user|assistant)$") 24 - content: str 25 14 26 - 27 - _messages: list[ChatMessage] = [] 28 - 29 - 30 - class ChatCompletionRequest(BaseModel): 31 - model: str 32 - messages: List[ChatMessage] 33 - chat_start: bool 34 - python_code: str 35 - max_tokens: Optional[int] = None 36 - temperature: Optional[float] = 0.7 37 - top_p: Optional[float] = 0.9 38 - stream: Optional[bool] = False 39 - stop: Optional[Union[str, List[str]]] = None 40 - repetition_penalty: Optional[float] = 1.1 41 - 42 - 43 - class StartRequest(BaseModel): 44 - model: str 45 - memory_path: str 46 - 47 - 48 - def format_chat_messages_for_runner( 49 - messages: List[ChatMessage], 50 - ) -> List[Dict[str, str]]: 51 - """Convert chat messages to format expected by MLXRunner. 15 + from typing import Any, Dict, List, Optional, Union 52 16 53 - Returns messages in dict format for the runner to apply chat templates. 54 - """ 55 - return [{"role": msg.role, "content": msg.content} for msg in messages] 17 + _model_cache: Dict[str, MLXRunner] = {} 18 + _default_max_tokens: Optional[int] = None # Use dynamic model-aware limits by default 19 + _current_model_path: Optional[str] = None 56 20 57 21 58 - def count_tokens(text: str) -> int: 59 - """Rough token count estimation.""" 60 - return int(len(text.split()) * 1.3) # Approximation, convert to int 22 + def download_model(model_name: str): 23 + """Download the model""" 24 + if pull_model(model_name): 25 + return {"message": "Model downloaded"} 26 + else: 27 + raise HTTPException(status_code=400, detail="Downloading model failed") 61 28 62 29 63 30 def get_or_load_model(model_spec: str, verbose: bool = False) -> MLXRunner: ··· 111 78 112 79 return _model_cache[model_path_str] 113 80 114 - 115 - def start_model(request: StartRequest): 116 - """Load the model and start the agent""" 117 - global _runner 118 - 119 - _runner = get_or_load_model(request.model) 120 - return {"message": "Model loaded"} 121 - 122 - 123 81 async def generate_chat_stream( 124 - model: str, messages: List[ChatMessage], request: ChatCompletionRequest 82 + messages: List[ChatMessage], request: ChatCompletionRequest 125 83 ) -> AsyncGenerator[str, None]: 126 84 """Generate streaming chat completion response.""" 127 85 128 - global _messages 86 + _messages = messages 129 87 completion_id = f"chatcmpl-{uuid.uuid4()}" 130 88 created = int(time.time()) 131 89 runner = get_or_load_model(request.model) ··· 209 167 210 168 yield f"data: {json.dumps(final_response)}\n\n" 211 169 yield "data: [DONE]\n\n" 170 + 171 + def format_chat_messages_for_runner( 172 + messages: List[ChatMessage], 173 + ) -> List[Dict[str, str]]: 174 + """Convert chat messages to format expected by MLXRunner. 175 + 176 + Returns messages in dict format for the runner to apply chat templates. 177 + """ 178 + return [{"role": msg.role, "content": msg.content} for msg in messages] 179 + 180 + 181 + def count_tokens(text: str) -> int: 182 + """Rough token count estimation.""" 183 + return int(len(text.split()) * 1.3) # Approximation, convert to int
+1 -1
server/mlx_runner.py
··· 19 19 from mlx_lm.generate import generate_step 20 20 from mlx_lm.sample_utils import make_repetition_penalty, make_sampler 21 21 22 - from .reasoning_utils import ReasoningExtractor, StreamingReasoningParser 22 + from ..reasoning_utils import ReasoningExtractor, StreamingReasoningParser 23 23 24 24 25 25 def get_model_context_length(model_path: str) -> int:
+1
server/runtime.py
··· 1 + backend = None
+65
server/schemas.py
··· 1 + from pydantic import BaseModel, Field 2 + from typing import Any, Dict, List, Optional, Union 3 + 4 + class CompletionRequest(BaseModel): 5 + model: str 6 + prompt: Union[str, List[str]] 7 + max_tokens: Optional[int] = None 8 + temperature: Optional[float] = 0.7 9 + top_p: Optional[float] = 0.9 10 + stream: Optional[bool] = False 11 + stop: Optional[Union[str, List[str]]] = None 12 + repetition_penalty: Optional[float] = 1.1 13 + 14 + 15 + class ChatMessage(BaseModel): 16 + role: str = Field(..., pattern="^(system|user|assistant)$") 17 + content: str 18 + 19 + 20 + class ChatCompletionRequest(BaseModel): 21 + model: str 22 + messages: List[ChatMessage] 23 + chat_start: bool 24 + python_code: str 25 + max_tokens: Optional[int] = None 26 + temperature: Optional[float] = 0.7 27 + top_p: Optional[float] = 0.9 28 + stream: Optional[bool] = False 29 + stop: Optional[Union[str, List[str]]] = None 30 + repetition_penalty: Optional[float] = 1.1 31 + 32 + 33 + class CompletionResponse(BaseModel): 34 + id: str 35 + object: str = "text_completion" 36 + created: int 37 + model: str 38 + choices: List[Dict[str, Any]] 39 + usage: Dict[str, int] 40 + 41 + 42 + class ChatCompletionResponse(BaseModel): 43 + id: str 44 + object: str = "chat.completion" 45 + created: int 46 + model: str 47 + choices: List[Dict[str, Any]] 48 + # usage: Dict[str, int] 49 + 50 + 51 + class ModelInfo(BaseModel): 52 + id: str 53 + object: str = "model" 54 + owned_by: str = "mlx-knife" 55 + permission: List = [] 56 + context_length: Optional[int] = None 57 + 58 + 59 + class StartRequest(BaseModel): 60 + model: str 61 + memory_path: str 62 + 63 + 64 + class downloadRequest(BaseModel): 65 + model: str
-16
src/runner/mlx.rs
··· 101 101 let stdout_log = File::create(config_dir.join("server.out.log"))?; 102 102 let stderr_log = File::create(config_dir.join("server.err.log"))?; 103 103 let server_path = server_dir.join(".venv/bin/python3"); 104 - println!("{:?}", server_path); 105 104 server_dir.pop(); 106 105 let child = Command::new(server_path) 107 106 .args(["-m", "server.main"]) ··· 274 273 275 274 let mut stream = res.bytes_stream(); 276 275 let mut accumulated = String::new(); 277 - // let mut inside_python = false; 278 - // let mut tag_buffer = String::new(); 279 276 println!(); 280 277 while let Some(chunk) = stream.next().await { 281 278 let chunk = chunk.unwrap(); ··· 300 297 } 301 298 } 302 299 } 303 - // println!("{:?}", res); 304 - // if res.status() == 200 { 305 - // let text = res.text().await.unwrap(); 306 - // let v: Value = serde_json::from_str(&text).unwrap(); 307 - // let content = v["choices"][0]["message"]["content"] 308 - // .as_str() 309 - // .unwrap_or("<no content>"); 310 - 311 - // // Ok(convert_to_chat_response(content)) 312 - // } else { 313 - // // Err(String::from("request failed")) 314 - // } 315 - // unimplemented!() 316 300 Err(String::from("request failed")) 317 301 } 318 302