···1212 cargo test
13131414serve:
1515- uv run --project server python -m server.main
1515+ server/.venv/bin/python3 -m server.main
16161717bundle:
1818 ./scripts/bundler.sh
-1
server/api.py
···6161 if request.stream:
6262 result = ({}, "")
6363 if request.python_code:
6464- create_memory_if_not_exists()
6564 result = execute_sandboxed_code(
6665 code=request.python_code,
6766 allowed_path=_memory_path,
+1-1
server/backend/mlx_runner.py
···249249250250 # Model-specific handling based on known patterns
251251 # Use reasoning_utils for reasoning model detection and patterns
252252- from .reasoning_utils import ReasoningExtractor
252252+ from ..reasoning_utils import ReasoningExtractor
253253254254 if hasattr(self.tokenizer, "name_or_path"):
255255 name_or_path = str(getattr(self.tokenizer, "name_or_path", "")).lower()
-1128
server/mlx_runner.py
···11-"""
22-Enhanced MLX model runner with direct API integration.
33-Provides ollama-like run experience with streaming and interactive chat.
44-"""
55-66-import sys
77-import json
88-import os
99-import time
1010-from collections.abc import Iterator
1111-from pathlib import Path
1212-from typing import Dict, Optional
1313-1414-if sys.platform == "darwin":
1515- import mlx.core as mx
1616-else:
1717- mx = None
1818-from mlx_lm import load
1919-from mlx_lm.generate import generate_step
2020-from mlx_lm.sample_utils import make_repetition_penalty, make_sampler
2121-2222-from ..reasoning_utils import ReasoningExtractor, StreamingReasoningParser
def get_model_context_length(model_path: str) -> int:
    """Return the model's maximum context length read from its ``config.json``.

    The recognised config keys are tried in order of preference and the first
    one holding a positive integer wins.

    Args:
        model_path: Path to the MLX model directory

    Returns:
        Maximum context length for the model. Falls back to 4096 when the
        config file is missing, unreadable, malformed, not a JSON object, or
        carries no usable value.
    """
    default_context_length = 4096
    # Common config keys for context length, preferred key first
    context_keys = (
        "max_position_embeddings",
        "n_positions",
        "context_length",
        "max_sequence_length",
        "seq_len",
    )
    config_path = os.path.join(model_path, "config.json")

    try:
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError):
        # OSError covers missing/unreadable paths (including a directory in
        # place of the file); ValueError covers JSONDecodeError and bad UTF-8.
        return default_context_length

    if not isinstance(config, dict):
        return default_context_length

    for key in context_keys:
        value = config.get(key)
        if isinstance(value, float) and value.is_integer():
            value = int(value)
        # bool is an int subclass, so it has to be excluded explicitly
        if isinstance(value, int) and not isinstance(value, bool) and value > 0:
            return value

    # No usable context length found: return a reasonable default
    return default_context_length
5959-6060-6161-class MLXRunner:
6262- """Direct MLX model runner with streaming and interactive capabilities."""
6363-6464- def __init__(
6565- self, model_path: str, adapter_path: Optional[str] = None, verbose: bool = False
6666- ):
6767- """Initialize the runner with a model.
6868-6969- Args:
7070- model_path: Path to the MLX model directory
7171- adapter_path: Optional path to LoRA adapter
7272- verbose: Show detailed output
7373- """
7474- self.model_path = Path(model_path)
7575- self.adapter_path = adapter_path
7676- self.model = None
7777- self.tokenizer = None
7878- self._memory_baseline = None
7979- self._stop_tokens = None # Will be populated from tokenizer
8080- self._message_end_tokens = None # Message-end tokens (e.g., <|end|> for MXFP4)
8181- self._chat_stop_tokens = None # Chat-specific stop tokens
8282- self._context_length = None # Will be populated from model config
8383- self._is_reasoning_model = False # Whether model uses reasoning (MXFP4)
8484- self._reasoning_start = None # Reasoning start marker
8585- self._reasoning_end = None # Reasoning end marker
8686- self._final_start = None # Final answer start marker
8787- self.verbose = verbose
8888- self._model_loaded = False
8989- self._context_entered = False # Prevent nested context usage
9090-9191- def __enter__(self):
9292- """Context manager entry - loads the model."""
9393- if self._context_entered:
9494- raise RuntimeError(
9595- "MLXRunner context manager cannot be entered multiple times"
9696- )
9797-9898- self._context_entered = True
9999- try:
100100- self.load_model()
101101- return self
102102- except Exception:
103103- # If load_model fails, ensure cleanup happens
104104- self._context_entered = False
105105- self.cleanup()
106106- raise
107107-108108- def __exit__(self, exc_type, exc_val, exc_tb):
109109- """Context manager exit - cleans up the model."""
110110- self._context_entered = False
111111- self.cleanup()
112112- return False # Don't suppress exceptions
113113-114114- def load_model(self):
115115- """Load the MLX model and tokenizer."""
116116- if self._model_loaded:
117117- if self.verbose:
118118- print("Model already loaded, skipping...")
119119- return
120120-121121- if self.verbose:
122122- print(f"Loading model from {self.model_path}...")
123123- start_time = time.time()
124124-125125- # Capture baseline memory before loading
126126- try:
127127- mx.clear_cache()
128128- except Exception:
129129- pass # Continue even if cache clear fails
130130- self._memory_baseline = mx.get_active_memory() / 1024**3
131131-132132- try:
133133- # Load model and tokenizer
134134- self.model, self.tokenizer = load(
135135- str(self.model_path), adapter_path=self.adapter_path
136136- )
137137-138138- load_time = time.time() - start_time
139139- current_memory = mx.get_active_memory() / 1024**3
140140- model_memory = current_memory - self._memory_baseline
141141-142142- if self.verbose:
143143- print(f"Model loaded in {load_time:.1f}s")
144144- print(
145145- f"Memory: {model_memory:.1f}GB model, {current_memory:.1f}GB total"
146146- )
147147-148148- # Extract stop tokens from tokenizer
149149- self._extract_stop_tokens()
150150-151151- # Extract context length from model config
152152- self._context_length = get_model_context_length(str(self.model_path))
153153-154154- if self.verbose:
155155- print(f"Model context length: {self._context_length} tokens")
156156-157157- self._model_loaded = True
158158-159159- except Exception as e:
160160- # Ensure partial state is cleaned up on failure
161161- self.model = None
162162- self.tokenizer = None
163163- self._stop_tokens = None
164164- self._model_loaded = False
165165- # Clear any memory that might have been allocated
166166- mx.clear_cache()
167167- raise RuntimeError(
168168- f"Failed to load model from {self.model_path}: {e}"
169169- ) from e
170170-171171- def _extract_stop_tokens(self):
172172- """Extract stop tokens from the tokenizer dynamically.
173173-174174- This method identifies ALL tokens that should stop generation:
175175- 1. Official EOS token from tokenizer config
176176- 2. Message-end tokens from training (e.g., <|end|> for MXFP4)
177177- 3. Common stop tokens across models
178178- """
179179- self._stop_tokens = set()
180180- self._message_end_tokens = (
181181- set()
182182- ) # Tokens that end messages but not conversations
183183-184184- # Primary source: eos_token
185185- eos_token = getattr(self.tokenizer, "eos_token", None)
186186- if eos_token:
187187- self._stop_tokens.add(eos_token)
188188-189189- # Also check pad_token if it's different from eos_token
190190- pad_token = getattr(self.tokenizer, "pad_token", None)
191191- if pad_token and pad_token != eos_token:
192192- self._stop_tokens.add(pad_token)
193193-194194- # Check additional_special_tokens
195195- if hasattr(self.tokenizer, "additional_special_tokens"):
196196- for token in self.tokenizer.additional_special_tokens:
197197- if token and isinstance(token, str):
198198- # Only add tokens that look like stop/end tokens
199199- if any(
200200- keyword in token.lower() for keyword in ["end", "stop", "eot"]
201201- ):
202202- self._stop_tokens.add(token)
203203-204204- # MLX-LM 0.27.0+: Extract tokens from added_tokens_decoder (comprehensive source)
205205- if hasattr(self.tokenizer, "added_tokens_decoder"):
206206- for _token_id, token_info in self.tokenizer.added_tokens_decoder.items():
207207- if isinstance(token_info, dict) and "content" in token_info:
208208- token_content = token_info["content"]
209209- if token_content and isinstance(token_content, str):
210210- token_lower = token_content.lower()
211211-212212- # NOTE: <|end|> is NOT a stop token for MXFP4 models!
213213- # It's a separator between reasoning and final answer
214214- if token_content == "<|end|>":
215215- self._message_end_tokens.add(token_content)
216216- # Do NOT add as stop token - let model continue to final answer
217217-218218- # Look for tokens that could be end/stop tokens
219219- # Expanded patterns for MLX-LM 0.27.0 token varieties
220220- # EXCLUDE <|end|> for MXFP4 models as it's a reasoning separator
221221- end_patterns = [
222222- "stop",
223223- "eot",
224224- "return",
225225- "finish",
226226- "done",
227227- "im_end",
228228- ]
229229- if any(pattern in token_lower for pattern in end_patterns):
230230- # Decide if it's a message-end or conversation-end token
231231- if "im_end" in token_lower:
232232- self._message_end_tokens.add(token_content)
233233- self._stop_tokens.add(token_content)
234234- # Special handling for 'end' pattern - more selective
235235- elif "end" in token_lower and token_content != "<|end|>":
236236- # Only add non-<|end|> tokens with 'end' in them
237237- self._stop_tokens.add(token_content)
238238-239239- # Special case: control tokens in |..| format
240240- elif token_content.startswith("<|") and token_content.endswith(
241241- "|>"
242242- ):
243243- # Be inclusive with control tokens that might stop generation
244244- if any(
245245- pattern in token_lower
246246- for pattern in ["end", "return", "stop", "finish"]
247247- ):
248248- self._stop_tokens.add(token_content)
249249-250250- # Model-specific handling based on known patterns
251251- # Use reasoning_utils for reasoning model detection and patterns
252252- from .reasoning_utils import ReasoningExtractor
253253-254254- if hasattr(self.tokenizer, "name_or_path"):
255255- name_or_path = str(getattr(self.tokenizer, "name_or_path", "")).lower()
256256- model_type = ReasoningExtractor.detect_model_type(name_or_path)
257257-258258- if model_type:
259259- # This is a reasoning model
260260- self._is_reasoning_model = True
261261-262262- # Get patterns from reasoning_utils
263263- if model_type in ReasoningExtractor.PATTERNS:
264264- markers = ReasoningExtractor.PATTERNS[model_type]["markers"]
265265- self._reasoning_start = markers.get("reasoning_start")
266266- self._reasoning_end = markers.get("reasoning_end")
267267- self._final_start = markers.get("final_marker")
268268-269269- # For reasoning models, remove reasoning_end from stop tokens
270270- if self._reasoning_end:
271271- self._stop_tokens.discard(self._reasoning_end)
272272-273273- # Add proper stop token for this model type
274274- if model_type == "gpt-oss":
275275- if "<|return|>" not in self._stop_tokens:
276276- self._stop_tokens.add("<|return|>")
277277- else:
278278- self._is_reasoning_model = False
279279- else:
280280- self._is_reasoning_model = False
281281-282282- # Add common stop tokens that might not be in special tokens
283283- common_stop_tokens = {"</s>", "<|endoftext|>", "<|im_end|>", "<|eot_id|>"}
284284-285285- # Add chat-specific stop tokens to prevent model self-conversations
286286- # Based on our _format_conversation() format: "Human:" and "Assistant:"
287287- # Also include "You:" as models might use UI-visible format
288288- # Include single-letter variations (H:, A:, Y:) that some models use
289289- chat_stop_tokens = {
290290- "\nHuman:",
291291- "\nAssistant:",
292292- "\nYou:",
293293- "\n\nHuman:",
294294- "\n\nAssistant:",
295295- "\n\nYou:",
296296- "\nH:",
297297- "\nA:",
298298- "\nY:", # Single-letter variations
299299- "\n\nH:",
300300- "\n\nA:",
301301- "\n\nY:",
302302- }
303303-304304- # Add common stop tokens only if they decode to themselves (i.e., they're single tokens)
305305- for token in common_stop_tokens:
306306- try:
307307- # Try to encode and decode to verify it's a real single token
308308- ids = self.tokenizer.encode(token, add_special_tokens=False)
309309- if ids and len(ids) == 1: # Single token ID means it's a special token
310310- decoded = self.tokenizer.decode(ids)
311311- if decoded == token:
312312- self._stop_tokens.add(token)
313313- except:
314314- pass
315315-316316- # Store chat stop tokens separately - only used in interactive chat mode
317317- # This prevents stopping mid-story when user asks for dialogues
318318- self._chat_stop_tokens = list(chat_stop_tokens)
319319-320320- # Remove any None values
321321- self._stop_tokens.discard(None)
322322- self._message_end_tokens.discard(None)
323323-324324- # Convert to list for easier use
325325- self._stop_tokens = list(self._stop_tokens)
326326- self._message_end_tokens = list(self._message_end_tokens)
327327-328328- if self.verbose:
329329- if self._stop_tokens:
330330- print(f"Stop tokens: {self._stop_tokens}")
331331- if self._message_end_tokens:
332332- print(f"Message end tokens: {self._message_end_tokens}")
333333-334334- def cleanup(self):
335335- """Clean up model resources and clear GPU memory.
336336-337337- This method is safe to call multiple times and handles partial state cleanup.
338338- """
339339- if self.verbose and self._model_loaded:
340340- memory_before = mx.get_active_memory() / 1024**3
341341- print(f"Cleaning up model (memory before: {memory_before:.1f}GB)...")
342342-343343- # Always clean up, even if model wasn't fully loaded
344344- self.model = None
345345- self.tokenizer = None
346346- self._stop_tokens = None
347347- self._message_end_tokens = None
348348- self._chat_stop_tokens = None
349349- self._context_length = None
350350- self._is_reasoning_model = False
351351- self._reasoning_start = None
352352- self._reasoning_end = None
353353- self._final_start = None
354354- self._model_loaded = False
355355-356356- # Force garbage collection and clear MLX cache
357357- import gc
358358-359359- gc.collect()
360360- try:
361361- mx.clear_cache()
362362- except Exception:
363363- pass # Continue cleanup even if cache clear fails
364364-365365- if self.verbose:
366366- memory_after = mx.get_active_memory() / 1024**3
367367- if "memory_before" in locals():
368368- memory_freed = memory_before - memory_after
369369- print(
370370- f"Cleanup complete (memory after: {memory_after:.1f}GB, freed: {memory_freed:.1f}GB)"
371371- )
372372- else:
373373- print(f"Cleanup complete (memory after: {memory_after:.1f}GB)")
374374-375375- def get_effective_max_tokens(
376376- self, requested_tokens: Optional[int], interactive: bool = False
377377- ) -> int:
378378- """Get effective max tokens based on model context and usage mode.
379379-380380- Args:
381381- requested_tokens: The requested max tokens (None if user didn't specify --max-tokens)
382382- interactive: True if this is interactive mode (gets full context length)
383383-384384- Returns:
385385- Effective max tokens to use
386386- """
387387- if not self._context_length:
388388- # Fallback when context length is unknown
389389- fallback = 4096 if interactive else 2048
390390- if self.verbose:
391391- if requested_tokens is None:
392392- print(
393393- f"[WARNING] Model context length unknown, using fallback: {fallback} tokens"
394394- )
395395- else:
396396- print(
397397- f"[WARNING] Model context length unknown, using user specified: {requested_tokens} tokens"
398398- )
399399- return requested_tokens if requested_tokens is not None else fallback
400400-401401- if interactive:
402402- if requested_tokens is None:
403403- # User didn't specify --max-tokens: use full model context
404404- return self._context_length
405405- else:
406406- # User specified --max-tokens explicitly: respect their choice but cap at context
407407- return min(requested_tokens, self._context_length)
408408- else:
409409- # Server/batch mode uses half context length for DoS protection
410410- server_limit = self._context_length // 2
411411- return min(requested_tokens or server_limit, server_limit)
412412-413413- def generate_streaming(
414414- self,
415415- prompt: str,
416416- max_tokens: int = 500,
417417- temperature: float = 0.7,
418418- top_p: float = 0.9,
419419- repetition_penalty: float = 1.1,
420420- repetition_context_size: int = 20,
421421- use_chat_template: bool = True,
422422- use_chat_stop_tokens: bool = False,
423423- interactive: bool = False,
424424- hide_reasoning: bool = False,
425425- ) -> Iterator[str]:
426426- """Generate text with streaming output.
427427-428428- Args:
429429- prompt: Input prompt
430430- max_tokens: Maximum tokens to generate
431431- temperature: Sampling temperature
432432- top_p: Top-p sampling parameter
433433- repetition_penalty: Penalty for repeated tokens
434434- repetition_context_size: Context size for repetition penalty
435435- use_chat_template: Apply tokenizer's chat template if available
436436- use_chat_stop_tokens: Include chat turn markers as stop tokens (for interactive mode)
437437- interactive: True if this is interactive mode (affects token limits)
438438-439439- Yields:
440440- Generated tokens as they are produced
441441- """
442442- if not self.model or not self.tokenizer:
443443- raise RuntimeError("Model not loaded. Call load_model() first.")
444444-445445- # Initialize reasoning parser if this is a reasoning model
446446- reasoning_parser = None
447447- if self._is_reasoning_model:
448448- model_type = ReasoningExtractor.detect_model_type(
449449- getattr(self.tokenizer, "name_or_path", "") or ""
450450- )
451451- reasoning_parser = StreamingReasoningParser(
452452- model_type, hide_reasoning=hide_reasoning
453453- )
454454-455455- # Apply context-aware token limits
456456- effective_max_tokens = self.get_effective_max_tokens(max_tokens, interactive)
457457-458458- # Apply chat template if available and requested
459459- if (
460460- use_chat_template
461461- and hasattr(self.tokenizer, "chat_template")
462462- and self.tokenizer.chat_template
463463- ):
464464- messages = [{"role": "user", "content": prompt}]
465465- formatted_prompt = self.tokenizer.apply_chat_template(
466466- messages, tokenize=False, add_generation_prompt=True
467467- )
468468- else:
469469- formatted_prompt = prompt
470470-471471- # Tokenize the prompt
472472- prompt_tokens = self.tokenizer.encode(formatted_prompt)
473473- prompt_array = mx.array(prompt_tokens)
474474-475475- # Track generation metrics
476476- start_time = time.time()
477477- tokens_generated = 0
478478-479479- # Create sampler with our parameters
480480- sampler = make_sampler(temp=temperature, top_p=top_p)
481481-482482- # Create repetition penalty processor if needed
483483- logits_processors = []
484484- if repetition_penalty > 1.0:
485485- logits_processors.append(
486486- make_repetition_penalty(repetition_penalty, repetition_context_size)
487487- )
488488-489489- # Generate tokens one by one for streaming
490490- generator = generate_step(
491491- prompt=prompt_array,
492492- model=self.model,
493493- max_tokens=effective_max_tokens,
494494- sampler=sampler,
495495- logits_processors=logits_processors if logits_processors else None,
496496- )
497497-498498- # Collect tokens and yield text
499499- generated_tokens = []
500500- previous_decoded = ""
501501- accumulated_response = "" # Track full response for stop token detection
502502-503503- # Keep a sliding window of recent tokens for context
504504- context_window = 10 # Decode last N tokens for proper spacing
505505-506506- for token, _ in generator:
507507- # Token might be an array or an int
508508- token_id = token.item() if hasattr(token, "item") else token
509509- generated_tokens.append(token_id)
510510-511511- # Use a sliding window approach for efficiency
512512- start_idx = max(0, len(generated_tokens) - context_window)
513513- window_tokens = generated_tokens[start_idx:]
514514-515515- # Decode the window
516516- window_text = self.tokenizer.decode(window_tokens)
517517-518518- # Figure out what's new
519519- if start_idx == 0:
520520- # We're still within the context window
521521- if window_text.startswith(previous_decoded):
522522- new_text = window_text[len(previous_decoded) :]
523523- else:
524524- new_text = self.tokenizer.decode([token_id])
525525- previous_decoded = window_text
526526- else:
527527- # We're beyond the context window, just decode the last token with context
528528- # This is approximate but should preserve spaces
529529- new_text = self.tokenizer.decode(window_tokens)
530530- if len(window_tokens) > 1:
531531- prefix = self.tokenizer.decode(window_tokens[:-1])
532532- if new_text.startswith(prefix):
533533- new_text = new_text[len(prefix) :]
534534- else:
535535- new_text = self.tokenizer.decode([token_id])
536536-537537- if new_text:
538538- # Update accumulated response for stop token checking
539539- accumulated_response += new_text
540540-541541- # Filter out stop tokens with priority: native first, then chat fallback
542542- # Check native stop tokens FIRST in accumulated response (highest priority)
543543- native_stop_tokens = self._stop_tokens if self._stop_tokens else []
544544- for stop_token in native_stop_tokens:
545545- if stop_token in accumulated_response:
546546- # Find the stop token position and yield everything before it
547547- stop_pos = accumulated_response.find(stop_token)
548548- # Calculate what text came before the stop token
549549- text_before_stop = accumulated_response[:stop_pos]
550550- # Calculate how much of that is new (not previously yielded)
551551- previously_yielded_length = len(accumulated_response) - len(
552552- new_text
553553- )
554554- if len(text_before_stop) > previously_yielded_length:
555555- # Yield only the new part before stop token
556556- new_part_before_stop = text_before_stop[
557557- previously_yielded_length:
558558- ]
559559- if new_part_before_stop:
560560- if reasoning_parser:
561561- # Process through reasoning parser for formatting
562562- for (
563563- formatted_token
564564- ) in reasoning_parser.process_token(
565565- new_part_before_stop
566566- ):
567567- yield formatted_token
568568- else:
569569- yield new_part_before_stop
570570- return # Stop generation without yielding stop token
571571-572572- # Only check chat stop tokens if no native stop token found (fallback)
573573- if use_chat_stop_tokens and self._chat_stop_tokens:
574574- for stop_token in self._chat_stop_tokens:
575575- if stop_token in accumulated_response:
576576- # Find the stop token position and yield everything before it
577577- stop_pos = accumulated_response.find(stop_token)
578578- # Calculate what text came before the stop token
579579- text_before_stop = accumulated_response[:stop_pos]
580580- # Calculate how much of that is new (not previously yielded)
581581- previously_yielded_length = len(accumulated_response) - len(
582582- new_text
583583- )
584584- if len(text_before_stop) > previously_yielded_length:
585585- # Yield only the new part before stop token
586586- new_part_before_stop = text_before_stop[
587587- previously_yielded_length:
588588- ]
589589- if new_part_before_stop:
590590- if reasoning_parser:
591591- # Process through reasoning parser for formatting
592592- for (
593593- formatted_token
594594- ) in reasoning_parser.process_token(
595595- new_part_before_stop
596596- ):
597597- yield formatted_token
598598- else:
599599- yield new_part_before_stop
600600- return # Stop generation without yielding stop token
601601-602602- # No stop token found, process the new text
603603- if reasoning_parser:
604604- # Process through reasoning parser for formatting
605605- for formatted_token in reasoning_parser.process_token(new_text):
606606- yield formatted_token
607607- else:
608608- # Normal streaming for non-reasoning models
609609- yield new_text
610610- tokens_generated += 1
611611-612612- # Check for EOS token - don't yield it
613613- if token_id == self.tokenizer.eos_token_id:
614614- break
615615-616616- # Finalize reasoning parser if used
617617- if reasoning_parser:
618618- yield from reasoning_parser.finalize()
619619-620620- # Print generation statistics if verbose
621621- if self.verbose:
622622- generation_time = time.time() - start_time
623623- tokens_per_second = (
624624- tokens_generated / generation_time if generation_time > 0 else 0
625625- )
626626- print(
627627- f"\n\nGenerated {tokens_generated} tokens in {generation_time:.1f}s ({tokens_per_second:.1f} tokens/s)"
628628- )
629629-630630- def generate_batch(
631631- self,
632632- prompt: str,
633633- max_tokens: int = 500,
634634- temperature: float = 0.7,
635635- top_p: float = 0.9,
636636- repetition_penalty: float = 1.1,
637637- repetition_context_size: int = 20,
638638- use_chat_template: bool = True,
639639- interactive: bool = False,
640640- ) -> str:
641641- """Generate text in batch mode (non-streaming).
642642-643643- Args:
644644- prompt: Input prompt
645645- max_tokens: Maximum tokens to generate
646646- temperature: Sampling temperature
647647- top_p: Top-p sampling parameter
648648- repetition_penalty: Penalty for repeated tokens
649649- repetition_context_size: Context size for repetition penalty
650650- use_chat_template: Apply tokenizer's chat template if available
651651- interactive: True if this is interactive mode (affects token limits)
652652-653653- Returns:
654654- Generated text
655655- """
656656- if not self.model or not self.tokenizer:
657657- raise RuntimeError("Model not loaded. Call load_model() first.")
658658-659659- # Apply context-aware token limits
660660- effective_max_tokens = self.get_effective_max_tokens(max_tokens, interactive)
661661-662662- # Apply chat template if available and requested
663663- if (
664664- use_chat_template
665665- and hasattr(self.tokenizer, "chat_template")
666666- and self.tokenizer.chat_template
667667- ):
668668- messages = [{"role": "user", "content": prompt}]
669669- formatted_prompt = self.tokenizer.apply_chat_template(
670670- messages, tokenize=False, add_generation_prompt=True
671671- )
672672- else:
673673- formatted_prompt = prompt
674674-675675- start_time = time.time()
676676-677677- # Tokenize the prompt
678678- prompt_tokens = self.tokenizer.encode(formatted_prompt)
679679- prompt_array = mx.array(prompt_tokens)
680680-681681- # Create sampler with our parameters
682682- sampler = make_sampler(temp=temperature, top_p=top_p)
683683-684684- # Create repetition penalty processor if needed
685685- logits_processors = []
686686- if repetition_penalty > 1.0:
687687- logits_processors.append(
688688- make_repetition_penalty(repetition_penalty, repetition_context_size)
689689- )
690690-691691- # Generate all tokens at once
692692- generated_tokens = []
693693- all_tokens = list(prompt_tokens) # Keep prompt for proper decoding
694694-695695- generator = generate_step(
696696- prompt=prompt_array,
697697- model=self.model,
698698- max_tokens=effective_max_tokens,
699699- sampler=sampler,
700700- logits_processors=logits_processors if logits_processors else None,
701701- )
702702-703703- for token, _ in generator:
704704- # Token might be an array or an int
705705- token_id = token.item() if hasattr(token, "item") else token
706706- generated_tokens.append(token_id)
707707- all_tokens.append(token_id)
708708-709709- # Check for EOS token - don't yield it
710710- if token_id == self.tokenizer.eos_token_id:
711711- break
712712-713713- # Decode all tokens together for proper spacing
714714- full_response = self.tokenizer.decode(all_tokens)
715715-716716- # Remove the prompt part
717717- if full_response.startswith(formatted_prompt):
718718- response = full_response[len(formatted_prompt) :]
719719- else:
720720- # Fallback: just decode generated tokens
721721- response = self.tokenizer.decode(generated_tokens)
722722-723723- # Apply end-token filtering (same logic as streaming mode for Issue #20)
724724- response = self._filter_end_tokens_from_response(
725725- response, use_chat_stop_tokens=False
726726- )
727727-728728- # Format reasoning models output
729729- response = self._format_reasoning_response(response)
730730-731731- generation_time = time.time() - start_time
732732-733733- # Count tokens for statistics
734734- if self.verbose:
735735- tokens_generated = len(generated_tokens)
736736- tokens_per_second = (
737737- tokens_generated / generation_time if generation_time > 0 else 0
738738- )
739739- print(
740740- f"\nGenerated {tokens_generated} tokens in {generation_time:.1f}s ({tokens_per_second:.1f} tokens/s)"
741741- )
742742-743743- return response
744744-745745- def interactive_chat(
746746- self,
747747- system_prompt: Optional[str] = None,
748748- max_tokens: int = 500,
749749- temperature: float = 0.7,
750750- top_p: float = 0.9,
751751- repetition_penalty: float = 1.1,
752752- use_chat_template: bool = True,
753753- ):
754754- """Run an interactive chat session.
755755-756756- Args:
757757- system_prompt: Optional system prompt to prepend
758758- max_tokens: Maximum tokens per response
759759- temperature: Sampling temperature
760760- top_p: Top-p sampling parameter
761761- repetition_penalty: Penalty for repeated tokens
762762- use_chat_template: Use tokenizer's chat template if available
763763- """
764764- print("Starting interactive chat. Type 'exit' or 'quit' to end.\n")
765765-766766- conversation_history = []
767767- if system_prompt:
768768- conversation_history.append({"role": "system", "content": system_prompt})
769769-770770- while True:
771771- try:
772772- # Get user input
773773- user_input = input("You: ").strip()
774774-775775- if user_input.lower() in ["exit", "quit", "q"]:
776776- print("\nGoodbye!")
777777- break
778778-779779- if not user_input:
780780- continue
781781-782782- # Add user message to history
783783- conversation_history.append({"role": "user", "content": user_input})
784784-785785- # Format conversation for the model using chat template if available
786786- prompt = self._format_conversation(
787787- conversation_history, use_chat_template=use_chat_template
788788- )
789789-790790- # Generate response with streaming
791791- print("\nAssistant: ", end="", flush=True)
792792-793793- response_tokens = []
794794- for token in self.generate_streaming(
795795- prompt=prompt,
796796- max_tokens=max_tokens,
797797- temperature=temperature,
798798- top_p=top_p,
799799- repetition_penalty=repetition_penalty,
800800- use_chat_template=False, # Already applied in _format_conversation
801801- use_chat_stop_tokens=True, # Enable chat stop tokens in interactive mode
802802- interactive=True, # Enable full context length for interactive mode
803803- ):
804804- # Stream all tokens directly (already formatted by generate_streaming)
805805- print(token, end="", flush=True)
806806- response_tokens.append(token)
807807-808808- # Add assistant response to history
809809- assistant_response = "".join(response_tokens).strip()
810810- conversation_history.append(
811811- {"role": "assistant", "content": assistant_response}
812812- )
813813-814814- print() # New line after response
815815-816816- except KeyboardInterrupt:
817817- print("\n\nChat interrupted. Goodbye!")
818818- break
819819- except Exception as e:
820820- print(f"\n[ERROR] {e}")
821821- continue
822822-823823- def _format_conversation(
824824- self, messages: list, use_chat_template: bool = True
825825- ) -> str:
826826- """Format conversation history into a prompt.
827827-828828- Uses the tokenizer's chat template if available, otherwise falls back
829829- to the legacy Human:/Assistant: format for compatibility.
830830-831831- Args:
832832- messages: List of message dictionaries with 'role' and 'content'
833833- use_chat_template: Whether to use chat template if available
834834-835835- Returns:
836836- Formatted conversation string
837837- """
838838- # Try to use native chat template if available
839839- if (
840840- use_chat_template
841841- and hasattr(self.tokenizer, "chat_template")
842842- and self.tokenizer.chat_template
843843- ):
844844- try:
845845- # Apply the tokenizer's chat template
846846- formatted_prompt = self.tokenizer.apply_chat_template(
847847- messages, tokenize=False, add_generation_prompt=True
848848- )
849849- return formatted_prompt
850850- except Exception as e:
851851- # If chat template fails, fall back to legacy format
852852- if self.verbose:
853853- print(f"[WARNING] Chat template failed, using legacy format: {e}")
854854-855855- # Legacy format fallback for compatibility
856856- return self._legacy_format_conversation(messages)
857857-858858- def _legacy_format_conversation(self, messages: list) -> str:
859859- """Legacy conversation formatting for backward compatibility.
860860-861861- This format was used in earlier versions and remains as a fallback
862862- for models without chat templates.
863863- """
864864- formatted = []
865865-866866- for message in messages:
867867- role = message["role"]
868868- content = message["content"]
869869-870870- if role == "system":
871871- formatted.append(f"System: {content}")
872872- elif role == "user":
873873- formatted.append(f"Human: {content}")
874874- elif role == "assistant":
875875- formatted.append(f"Assistant: {content}")
876876-877877- # Add prompt for next assistant response
878878- formatted.append("Assistant:")
879879-880880- return "\n\n".join(formatted)
881881-882882- def get_memory_usage(self) -> Dict[str, float]:
883883- """Get current memory usage statistics.
884884-885885- Returns:
886886- Dictionary with memory statistics in GB
887887- """
888888- try:
889889- current_memory = mx.get_active_memory() / 1024**3
890890- peak_memory = mx.get_peak_memory() / 1024**3
891891- except Exception:
892892- # Return zeros if memory stats unavailable
893893- current_memory = 0.0
894894- peak_memory = 0.0
895895-896896- return {
897897- "current_gb": current_memory,
898898- "peak_gb": peak_memory,
899899- "model_gb": (
900900- current_memory - self._memory_baseline if self._memory_baseline else 0
901901- ),
902902- }
903903-904904- def _format_reasoning_response(self, response: str) -> str:
905905- """Format response from reasoning models for better readability.
906906-907907- For MXFP4 models that generate reasoning followed by final answer,
908908- format it nicely for display.
909909- """
910910- if not self._is_reasoning_model:
911911- return response
912912-913913- # Check if response contains reasoning markers
914914- if self._reasoning_start in response and self._final_start in response:
915915- # Extract reasoning and final parts
916916- try:
917917- # Split on the reasoning start
918918- before_reasoning, after_start = response.split(self._reasoning_start, 1)
919919-920920- # Find the reasoning content (until <|end|>)
921921- if self._reasoning_end in after_start:
922922- reasoning_content, after_reasoning = after_start.split(
923923- self._reasoning_end, 1
924924- )
925925-926926- # Find the final answer
927927- if self._final_start in after_reasoning:
928928- # Extract everything after final marker
929929- final_parts = after_reasoning.split(self._final_start, 1)
930930- if len(final_parts) > 1:
931931- # Remove the <|channel|>final<|message|> marker
932932- final_answer = final_parts[1].replace(
933933- "<|channel|>final<|message|>", "", 1
934934- )
935935-936936- # Format with clear markers for parsing but minimal visual impact
937937- formatted = []
938938- formatted.append("\n**[Reasoning]**\n")
939939- formatted.append(reasoning_content.strip())
940940- formatted.append("\n\n---\n\n**[Answer]**\n")
941941- formatted.append(final_answer.strip())
942942-943943- return "\n".join(formatted)
944944- except Exception:
945945- # If parsing fails, return original
946946- pass
947947-948948- # Fallback: just clean up the control tokens
949949- cleaned = response
950950- for marker in [
951951- "<|channel|>analysis<|message|>",
952952- "<|end|>",
953953- "<|start|>assistant",
954954- "<|channel|>final<|message|>",
955955- "<|return|>",
956956- ]:
957957- cleaned = cleaned.replace(marker, "")
958958-959959- return cleaned.strip()
960960-961961- def _filter_end_tokens_from_response(
962962- self, response: str, use_chat_stop_tokens: bool = False
963963- ) -> str:
964964- """Filter end tokens from a complete response (batch mode).
965965-966966- This method applies the same filtering logic as the streaming mode
967967- to ensure consistent behavior between streaming and non-streaming.
968968-969969- Args:
970970- response: The complete generated response
971971- use_chat_stop_tokens: Whether to apply chat stop tokens
972972-973973- Returns:
974974- Response with end tokens filtered out
975975- """
976976- # Apply native stop token filtering FIRST (highest priority)
977977- native_stop_tokens = self._stop_tokens if self._stop_tokens else []
978978- for stop_token in native_stop_tokens:
979979- if stop_token in response:
980980- # Find the stop token position and return everything before it
981981- stop_pos = response.find(stop_token)
982982- filtered_response = response[:stop_pos].rstrip()
983983- if self.verbose:
984984- print(
985985- f"[DEBUG] Filtered stop token '{stop_token}' at position {stop_pos}"
986986- )
987987- return filtered_response
988988-989989- # Only check chat stop tokens if no native stop token found (fallback)
990990- if use_chat_stop_tokens and self._chat_stop_tokens:
991991- for stop_token in self._chat_stop_tokens:
992992- if stop_token in response:
993993- # Find the stop token position and return everything before it
994994- stop_pos = response.find(stop_token)
995995- return response[:stop_pos]
996996-997997- # No stop tokens found, return original response
998998- return response
def get_gpu_status() -> Dict[str, float]:
    """Independent GPU status check - usable from anywhere.

    The module sets ``mx`` to ``None`` on platforms other than macOS, so the
    memory queries are treated as best effort: when they cannot be performed
    both values are reported as 0.0 instead of raising. This mirrors the
    fallback used by ``MLXRunner.get_memory_usage``.

    Returns:
        Dictionary with GPU memory statistics in GB
        (``active_memory_gb`` and ``peak_memory_gb``).
    """
    bytes_per_gb = 1024**3
    try:
        active_gb = mx.get_active_memory() / bytes_per_gb
        peak_gb = mx.get_peak_memory() / bytes_per_gb
    except Exception:
        # Memory statistics unavailable (e.g. MLX core not loaded on this
        # platform): a status probe should degrade to zeros, not crash.
        active_gb = 0.0
        peak_gb = 0.0

    return {
        "active_memory_gb": active_gb,
        "peak_memory_gb": peak_gb,
    }
def check_memory_available(
    required_gb: float, total_gb: float = 8.0, headroom_gb: float = 2.0
) -> bool:
    """Pre-flight check before model loading.

    Estimates the free memory as ``total_gb - active - headroom_gb``, where
    ``active`` is the memory currently reported by ``mx``. If that figure
    cannot be queried (the module sets ``mx`` to ``None`` off macOS), active
    usage is taken as 0 so the check still returns an answer instead of
    raising.

    Args:
        required_gb: Required memory in GB
        total_gb: Assumed total unified memory in GB. The default of 8 GB is
            a fixed assumption, not a measurement; callers that know
            the real machine size should pass it.
        headroom_gb: Memory in GB to leave untouched for system processes.

    Returns:
        True if the estimated available memory covers ``required_gb``.
    """
    try:
        current_memory = mx.get_active_memory() / 1024**3
    except Exception:
        # No usable memory statistics: fall back to "nothing allocated yet".
        current_memory = 0.0

    available = total_gb - current_memory - headroom_gb

    return available >= required_gb
10301030-10311031-10321032-def run_model_enhanced(
10331033- model_path: str,
10341034- prompt: Optional[str] = None,
10351035- interactive: bool = False,
10361036- max_tokens: int = 500,
10371037- temperature: float = 0.7,
10381038- top_p: float = 0.9,
10391039- repetition_penalty: float = 1.1,
10401040- stream: bool = True,
10411041- use_chat_template: bool = True,
10421042- hide_reasoning: bool = False,
10431043- verbose: bool = False,
10441044-) -> Optional[str]:
10451045- """Enhanced run function with direct MLX integration.
10461046-10471047- Uses context manager pattern for automatic resource cleanup.
10481048-10491049- Args:
10501050- model_path: Path to the MLX model
10511051- prompt: Input prompt (if None, enters interactive mode)
10521052- interactive: Force interactive mode
10531053- max_tokens: Maximum tokens to generate
10541054- temperature: Sampling temperature
10551055- top_p: Top-p sampling parameter
10561056- repetition_penalty: Penalty for repeated tokens
10571057- stream: Whether to stream output
10581058-10591059- Returns:
10601060- Generated text (in non-interactive mode)
10611061- """
10621062- try:
10631063- with MLXRunner(model_path, verbose=verbose) as runner:
10641064- # Interactive mode
10651065- if interactive or prompt is None:
10661066- runner.interactive_chat(
10671067- max_tokens=max_tokens,
10681068- temperature=temperature,
10691069- top_p=top_p,
10701070- repetition_penalty=repetition_penalty,
10711071- use_chat_template=use_chat_template,
10721072- )
10731073- return None
10741074-10751075- # Single prompt mode
10761076- if verbose:
10771077- print(f"\nPrompt: {prompt}\n")
10781078- print("Response: ", end="", flush=True)
10791079-10801080- if stream:
10811081- # Streaming generation
10821082- response_tokens = []
10831083- try:
10841084- for token in runner.generate_streaming(
10851085- prompt=prompt,
10861086- max_tokens=max_tokens,
10871087- temperature=temperature,
10881088- top_p=top_p,
10891089- repetition_penalty=repetition_penalty,
10901090- use_chat_template=use_chat_template,
10911091- hide_reasoning=hide_reasoning,
10921092- ):
10931093- # Stream all tokens directly (already formatted by generate_streaming)
10941094- print(token, end="", flush=True)
10951095- response_tokens.append(token)
10961096- except KeyboardInterrupt:
10971097- print("\n[INFO] Generation interrupted by user.")
10981098- response = "".join(response_tokens)
10991099- else:
11001100- # Batch generation
11011101- try:
11021102- response = runner.generate_batch(
11031103- prompt=prompt,
11041104- max_tokens=max_tokens,
11051105- temperature=temperature,
11061106- top_p=top_p,
11071107- repetition_penalty=repetition_penalty,
11081108- use_chat_template=use_chat_template,
11091109- )
11101110- except KeyboardInterrupt:
11111111- print("\n[INFO] Generation interrupted by user.")
11121112- response = ""
11131113- print(response)
11141114-11151115- # Show memory usage if verbose
11161116- if verbose:
11171117- memory_stats = runner.get_memory_usage()
11181118- print(
11191119- f"\n\nMemory: {memory_stats['model_gb']:.1f}GB model, {memory_stats['current_gb']:.1f}GB total"
11201120- )
11211121-11221122- return response
11231123-11241124- # Note: cleanup happens automatically due to context manager
11251125-11261126- except Exception as e:
11271127- print(f"\n[ERROR] {e}")
11281128- return None