A safe, simple, extensible, and fast agent harness
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

Improve the test suite ahead of the first public release (#26)

authored by

Mason Stallmo and committed by
GitHub
674b1775 a363ae2f

+1526 -66
+5
.config/nextest.toml
··· 1 + [profile.default] 2 + fail-fast = false 3 + 4 + [profile.ci] 5 + fail-fast = true
+27
.github/workflows/test.yml
··· 1 + name: Test 2 + 3 + on: 4 + push: 5 + branches: [main] 6 + pull_request: 7 + branches: [main] 8 + 9 + jobs: 10 + test: 11 + runs-on: ubuntu-22.04 12 + steps: 13 + - uses: actions/checkout@v4 14 + 15 + - name: Install protoc 16 + uses: arduino/setup-protoc@v3 17 + with: 18 + repo-token: ${{ secrets.GITHUB_TOKEN }} 19 + 20 + - name: Add wasm32-wasip2 target 21 + run: rustup target add wasm32-wasip2 22 + 23 + - name: Install cargo-nextest 24 + uses: taiki-e/install-action@nextest 25 + 26 + - name: Run tests 27 + run: cargo nextest run --workspace --profile ci
+3
Cargo.lock
··· 1094 1094 "ein_plugin", 1095 1095 "serde", 1096 1096 "serde_json", 1097 + "tempfile", 1097 1098 "wit-bindgen 0.53.1", 1098 1099 ] 1099 1100 ··· 1160 1161 "ein_plugin", 1161 1162 "serde", 1162 1163 "serde_json", 1164 + "tempfile", 1163 1165 "wit-bindgen 0.53.1", 1164 1166 ] 1165 1167 ··· 1171 1173 "ein_plugin", 1172 1174 "serde", 1173 1175 "serde_json", 1176 + "tempfile", 1174 1177 "wit-bindgen 0.53.1", 1175 1178 ] 1176 1179
+329 -1
crates/ein-agent/src/agents.rs
··· 520 520 use std::sync::Mutex; 521 521 522 522 use async_trait::async_trait; 523 - use ein_core::types::{Choice, CompletionResponse, ToolDef, ToolResult}; 523 + use ein_core::types::{Choice, CompletionResponse, ToolDef, ToolResult, Usage}; 524 524 525 525 use super::*; 526 526 ··· 691 691 let _ = agent.chat("user message").await.unwrap(); 692 692 693 693 assert_eq!(*tool.called_arg.lock().unwrap(), "test_val".to_string()); 694 + } 695 + 696 + // --------------------------------------------------------------------------- 697 + // Shared test helpers 698 + // --------------------------------------------------------------------------- 699 + 700 + fn tool_msg(id: &str, content: impl Into<String>) -> Message { 701 + Message { 702 + role: Role::Tool, 703 + content: Some(content.into()), 704 + tool_call_id: Some(id.to_string()), 705 + tool_calls: None, 706 + } 707 + } 708 + 709 + fn user_msg(content: impl Into<String>) -> Message { 710 + Message { 711 + role: Role::User, 712 + content: Some(content.into()), 713 + tool_calls: None, 714 + tool_call_id: None, 715 + } 716 + } 717 + 718 + fn system_msg(content: impl Into<String>) -> Message { 719 + Message { 720 + role: Role::System, 721 + content: Some(content.into()), 722 + tool_calls: None, 723 + tool_call_id: None, 724 + } 725 + } 726 + 727 + fn stop_response(content: &str) -> CompletionResponse { 728 + CompletionResponse { 729 + choices: vec![Choice { 730 + index: None, 731 + finish_reason: FinishReason::Stop, 732 + message: Message { 733 + role: Role::Assistant, 734 + content: Some(content.to_string()), 735 + tool_calls: None, 736 + tool_call_id: None, 737 + }, 738 + }], 739 + usage: None, 740 + error: None, 741 + } 742 + } 743 + 744 + // --------------------------------------------------------------------------- 745 + // truncate_old_tool_results — tested directly (private method, same file) 746 + // --------------------------------------------------------------------------- 747 + 748 + const TEST_THRESHOLD: usize = 50; 749 + const TEST_WINDOW: usize = 2; 750 + 751 + #[test] 752 + fn truncate_old_tool_results_replaces_large_stale_content() { 753 + let large = "x".repeat(TEST_THRESHOLD + 1); 754 + let history = vec![ 755 + tool_msg("t1", &large), 756 + tool_msg("t2", &large), 757 + user_msg("recent 1"), 758 + user_msg("recent 2"), 759 + ]; 760 + 761 + let mut agent = Agent::builder(basic_test_client()) 762 + .num_recent_messages(TEST_WINDOW) 763 + .max_tool_result_chars(TEST_THRESHOLD) 764 + .with_message_history(history) 765 + .build(); 766 + 767 + agent.truncate_old_tool_results(); 768 + 769 + let msgs = agent.messages(); 770 + assert!( 771 + msgs[0].content.as_deref().unwrap_or("").starts_with("[Tool result truncated:"), 772 + "old large tool result must be truncated" 773 + ); 774 + assert!( 775 + msgs[1].content.as_deref().unwrap_or("").starts_with("[Tool result truncated:"), 776 + "old large tool result must be truncated" 777 + ); 778 + assert_eq!(msgs[2].content.as_deref(), Some("recent 1")); 779 + assert_eq!(msgs[3].content.as_deref(), Some("recent 2")); 780 + } 781 + 782 + #[test] 783 + fn truncate_old_tool_results_keeps_recent_messages_intact() { 784 + let large = "x".repeat(TEST_THRESHOLD + 1); 785 + // All 3 messages are within the window of 3 — none should be truncated. 786 + let history = vec![ 787 + tool_msg("t1", &large), 788 + tool_msg("t2", &large), 789 + tool_msg("t3", &large), 790 + ]; 791 + 792 + let mut agent = Agent::builder(basic_test_client()) 793 + .num_recent_messages(3) 794 + .max_tool_result_chars(TEST_THRESHOLD) 795 + .with_message_history(history) 796 + .build(); 797 + 798 + agent.truncate_old_tool_results(); 799 + 800 + for msg in agent.messages() { 801 + assert!( 802 + !msg.content.as_deref().unwrap_or("").starts_with("[Tool result truncated:"), 803 + "recent messages must not be truncated" 804 + ); 805 + } 806 + } 807 + 808 + #[test] 809 + fn truncate_old_tool_results_ignores_non_tool_messages() { 810 + let large = "x".repeat(TEST_THRESHOLD + 1); 811 + let history = vec![ 812 + user_msg(&large), 813 + system_msg(&large), 814 + tool_msg("t1", "small"), 815 + ]; 816 + 817 + let mut agent = Agent::builder(basic_test_client()) 818 + .num_recent_messages(TEST_WINDOW) 819 + .max_tool_result_chars(TEST_THRESHOLD) 820 + .with_message_history(history) 821 + .build(); 822 + 823 + agent.truncate_old_tool_results(); 824 + 825 + let msgs = agent.messages(); 826 + assert_eq!(msgs[0].content.as_deref(), Some(large.as_str()), "User must not be truncated"); 827 + assert_eq!(msgs[1].content.as_deref(), Some(large.as_str()), "System must not be truncated"); 828 + } 829 + 830 + #[test] 831 + fn truncate_old_tool_results_skips_content_at_threshold() { 832 + // content length == threshold is NOT truncated (condition is strictly >) 833 + let at_threshold = "x".repeat(TEST_THRESHOLD); 834 + let history = vec![ 835 + tool_msg("t1", &at_threshold), 836 + user_msg("recent 1"), 837 + user_msg("recent 2"), 838 + ]; 839 + 840 + let mut agent = Agent::builder(basic_test_client()) 841 + .num_recent_messages(TEST_WINDOW) 842 + .max_tool_result_chars(TEST_THRESHOLD) 843 + .with_message_history(history) 844 + .build(); 845 + 846 + agent.truncate_old_tool_results(); 847 + 848 + assert_eq!( 849 + agent.messages()[0].content.as_deref(), 850 + Some(at_threshold.as_str()), 851 + "content exactly at threshold must not be truncated" 852 + ); 853 + } 854 + 855 + // --------------------------------------------------------------------------- 856 + // compact_history 857 + // --------------------------------------------------------------------------- 858 + 859 + #[tokio::test] 860 + async fn compact_history_returns_empty_when_no_user_messages() { 861 + let mut agent = Agent::builder(basic_test_client()) 862 + .with_message_history(vec![system_msg("you are helpful")]) 863 + .build(); 864 + 865 + let result = agent.compact_history().await.unwrap(); 866 + assert_eq!(result, "", "nothing to compact without user messages"); 867 + } 868 + 869 + #[tokio::test] 870 + async fn compact_history_replaces_history_with_system_plus_summary() { 871 + let summary = "Goals discussed, files modified, current state."; 872 + let mut agent = Agent::builder(BasicTestModelClient { 873 + response: stop_response(summary), 874 + }) 875 + .with_message_history(vec![system_msg("sys"), user_msg("do stuff")]) 876 + .build(); 877 + 878 + let returned = agent.compact_history().await.unwrap(); 879 + assert_eq!(returned, summary); 880 + 881 + let msgs = agent.messages(); 882 + assert_eq!(msgs.len(), 2, "original system + new summary system"); 883 + assert!(matches!(msgs[0].role, Role::System)); 884 + assert_eq!(msgs[0].content.as_deref(), Some("sys")); 885 + assert!(matches!(msgs[1].role, Role::System)); 886 + assert!(msgs[1].content.as_deref().unwrap_or("").contains(summary)); 887 + } 888 + 889 + #[tokio::test] 890 + async fn compact_history_broadcasts_content_delta_event() { 891 + use std::sync::Arc; 892 + 893 + let summary = "Detailed summary."; 894 + let captured: Arc<Mutex<Vec<AgentEvent>>> = Arc::new(Mutex::new(Vec::new())); 895 + let cap = captured.clone(); 896 + 897 + let mut agent = Agent::builder(BasicTestModelClient { 898 + response: stop_response(summary), 899 + }) 900 + .with_event_handler(move |event| { 901 + let cap = cap.clone(); 902 + async move { cap.lock().unwrap().push(event); } 903 + }) 904 + .with_message_history(vec![user_msg("do stuff")]) 905 + .build(); 906 + 907 + agent.compact_history().await.unwrap(); 908 + 909 + let events = captured.lock().unwrap(); 910 + let deltas: Vec<&str> = events 911 + .iter() 912 + .filter_map(|e| { 913 + if let AgentEvent::ContentDelta(t) = e { Some(t.as_str()) } else { None } 914 + }) 915 + .collect(); 916 + assert_eq!(deltas, vec![summary]); 917 + } 918 + 919 + // --------------------------------------------------------------------------- 920 + // chat error paths 921 + // --------------------------------------------------------------------------- 922 + 923 + #[tokio::test] 924 + async fn chat_returns_error_on_api_error_response() { 925 + let mut agent = Agent::builder(BasicTestModelClient { 926 + response: CompletionResponse { 927 + choices: vec![], 928 + usage: None, 929 + error: Some(serde_json::json!({"message": "insufficient credits"})), 930 + }, 931 + }) 932 + .build(); 933 + 934 + let err = agent.chat("prompt").await.unwrap_err(); 935 + assert!(matches!(err, AgentError::ModelClient(_))); 936 + assert!(err.to_string().contains("insufficient credits")); 937 + } 938 + 939 + #[tokio::test] 940 + async fn chat_returns_error_on_unsupported_finish_reason() { 941 + let mut agent = Agent::builder(BasicTestModelClient { 942 + response: CompletionResponse { 943 + choices: vec![Choice { 944 + index: None, 945 + finish_reason: FinishReason::Unsupported, 946 + message: Message { 947 + role: Role::Assistant, 948 + content: None, 949 + tool_calls: None, 950 + tool_call_id: None, 951 + }, 952 + }], 953 + usage: None, 954 + error: None, 955 + }, 956 + }) 957 + .build(); 958 + 959 + let err = agent.chat("prompt").await.unwrap_err(); 960 + assert!(matches!(err, AgentError::UnsupportedFinishReason(_))); 961 + } 962 + 963 + // --------------------------------------------------------------------------- 964 + // Token usage events and clear 965 + // --------------------------------------------------------------------------- 966 + 967 + #[tokio::test] 968 + async fn chat_emits_token_usage_event() { 969 + use std::sync::Arc; 970 + 971 + let captured: Arc<Mutex<Vec<AgentEvent>>> = Arc::new(Mutex::new(Vec::new())); 972 + let cap = captured.clone(); 973 + 974 + let mut agent = Agent::builder(BasicTestModelClient { 975 + response: CompletionResponse { 976 + choices: vec![Choice { 977 + index: None, 978 + finish_reason: FinishReason::Stop, 979 + message: Message { 980 + role: Role::Assistant, 981 + content: Some("done".to_string()), 982 + tool_calls: None, 983 + tool_call_id: None, 984 + }, 985 + }], 986 + usage: Some(Usage { 987 + prompt_tokens: 10, 988 + completion_tokens: 5, 989 + total_tokens: 15, 990 + }), 991 + error: None, 992 + }, 993 + }) 994 + .with_event_handler(move |event| { 995 + let cap = cap.clone(); 996 + async move { cap.lock().unwrap().push(event); } 997 + }) 998 + .build(); 999 + 1000 + agent.chat("hello").await.unwrap(); 1001 + 1002 + let events = captured.lock().unwrap(); 1003 + let usage = events.iter().find_map(|e| { 1004 + if let AgentEvent::TokenUsage { prompt_tokens, completion_tokens, total_tokens } = e { 1005 + Some((*prompt_tokens, *completion_tokens, *total_tokens)) 1006 + } else { 1007 + None 1008 + } 1009 + }); 1010 + assert_eq!(usage, Some((10, 5, 15)), "TokenUsage event must carry correct totals"); 1011 + } 1012 + 1013 + #[tokio::test] 1014 + async fn clear_messages_empties_history() { 1015 + let mut agent = Agent::builder(basic_test_client()) 1016 + .with_message_history(vec![system_msg("sys"), user_msg("hello")]) 1017 + .build(); 1018 + 1019 + assert!(!agent.messages().is_empty()); 1020 + agent.clear_messages(); 1021 + assert!(agent.messages().is_empty()); 694 1022 } 695 1023 }
+163
crates/ein-core/src/types.rs
··· 262 262 } 263 263 } 264 264 } 265 + 266 + #[cfg(test)] 267 + mod tests { 268 + use super::*; 269 + 270 + // --------------------------------------------------------------------------- 271 + // ToolFunctionBuilder 272 + // --------------------------------------------------------------------------- 273 + 274 + #[test] 275 + fn builder_produces_function_tool_def() { 276 + let tool = ToolDef::function("my_tool", "does a thing").build(); 277 + assert!(matches!(tool, ToolDef::Function { .. })); 278 + } 279 + 280 + #[test] 281 + fn builder_required_param_appears_in_required_array() { 282 + let tool = ToolDef::function("t", "d") 283 + .param("req_param", "string", "a required param", true) 284 + .build(); 285 + 286 + let ToolDef::Function { function } = tool; 287 + let ToolFunctionParams::Object { required, .. } = &function.parameters; 288 + assert!(required.contains(&"req_param".to_string())); 289 + } 290 + 291 + #[test] 292 + fn builder_optional_param_absent_from_required_array() { 293 + let tool = ToolDef::function("t", "d") 294 + .param("opt_param", "string", "an optional param", false) 295 + .build(); 296 + 297 + let ToolDef::Function { function } = tool; 298 + let ToolFunctionParams::Object { required, .. } = &function.parameters; 299 + assert!(!required.contains(&"opt_param".to_string())); 300 + } 301 + 302 + #[test] 303 + fn builder_param_appears_in_properties() { 304 + let tool = ToolDef::function("t", "d") 305 + .param("my_arg", "integer", "an arg", true) 306 + .build(); 307 + 308 + let ToolDef::Function { function } = tool; 309 + let ToolFunctionParams::Object { properties, .. } = &function.parameters; 310 + assert!(properties.props().contains_key("my_arg")); 311 + } 312 + 313 + #[test] 314 + fn builder_serializes_to_openai_schema_shape() { 315 + let tool = ToolDef::function("bash", "run a shell command") 316 + .param("command", "string", "the command to run", true) 317 + .build(); 318 + 319 + let v = serde_json::to_value(&tool).unwrap(); 320 + assert_eq!(v["type"], "function"); 321 + assert_eq!(v["function"]["name"], "bash"); 322 + assert_eq!(v["function"]["parameters"]["type"], "object"); 323 + assert_eq!( 324 + v["function"]["parameters"]["properties"]["command"]["type"], 325 + "string" 326 + ); 327 + assert_eq!(v["function"]["parameters"]["required"][0], "command"); 328 + } 329 + 330 + // --------------------------------------------------------------------------- 331 + // Message serialization 332 + // --------------------------------------------------------------------------- 333 + 334 + #[test] 335 + fn message_round_trips_for_each_role() { 336 + for role in [Role::System, Role::User, Role::Assistant, Role::Tool] { 337 + let msg = Message { 338 + role: role.clone(), 339 + content: Some("hello".to_string()), 340 + tool_calls: None, 341 + tool_call_id: None, 342 + }; 343 + let json = serde_json::to_string(&msg).unwrap(); 344 + let decoded: Message = serde_json::from_str(&json).unwrap(); 345 + assert_eq!(decoded.role, role); 346 + assert_eq!(decoded.content.as_deref(), Some("hello")); 347 + } 348 + } 349 + 350 + #[test] 351 + fn message_omits_none_fields_from_json() { 352 + let msg = Message { 353 + role: Role::User, 354 + content: Some("hi".to_string()), 355 + tool_calls: None, 356 + tool_call_id: None, 357 + }; 358 + let v = serde_json::to_value(&msg).unwrap(); 359 + assert!(!v.as_object().unwrap().contains_key("tool_calls")); 360 + assert!(!v.as_object().unwrap().contains_key("tool_call_id")); 361 + } 362 + 363 + // --------------------------------------------------------------------------- 364 + // FinishReason 365 + // --------------------------------------------------------------------------- 366 + 367 + #[test] 368 + fn finish_reason_unknown_value_deserializes_as_unsupported() { 369 + let json = r#""length""#; 370 + let reason: FinishReason = serde_json::from_str(json).unwrap(); 371 + assert!(matches!(reason, FinishReason::Unsupported)); 372 + } 373 + 374 + #[test] 375 + fn finish_reason_known_values_round_trip() { 376 + for (s, expected) in [ 377 + ("\"stop\"", FinishReason::Stop), 378 + ("\"tool_calls\"", FinishReason::ToolCalls), 379 + ] { 380 + let reason: FinishReason = serde_json::from_str(s).unwrap(); 381 + assert!(matches!( 382 + (reason, expected), 383 + (FinishReason::Stop, FinishReason::Stop) 384 + | (FinishReason::ToolCalls, FinishReason::ToolCalls) 385 + )); 386 + } 387 + } 388 + 389 + // --------------------------------------------------------------------------- 390 + // CompletionRequest / CompletionResponse 391 + // --------------------------------------------------------------------------- 392 + 393 + #[test] 394 + fn completion_request_round_trips() { 395 + let req = CompletionRequest { 396 + model: "gpt-4o".to_string(), 397 + max_tokens: 1024, 398 + messages: vec![Message { 399 + role: Role::User, 400 + content: Some("hello".to_string()), 401 + tool_calls: None, 402 + tool_call_id: None, 403 + }], 404 + tools: vec![], 405 + }; 406 + let json = serde_json::to_string(&req).unwrap(); 407 + let decoded: CompletionRequest = serde_json::from_str(&json).unwrap(); 408 + assert_eq!(decoded.model, "gpt-4o"); 409 + assert_eq!(decoded.max_tokens, 1024); 410 + assert_eq!(decoded.messages.len(), 1); 411 + } 412 + 413 + #[test] 414 + fn completion_response_error_field_round_trips() { 415 + let resp = CompletionResponse { 416 + choices: vec![], 417 + usage: None, 418 + error: Some(serde_json::json!({"message": "insufficient credits"})), 419 + }; 420 + let json = serde_json::to_string(&resp).unwrap(); 421 + let decoded: CompletionResponse = serde_json::from_str(&json).unwrap(); 422 + assert_eq!( 423 + decoded.error.unwrap()["message"], 424 + "insufficient credits" 425 + ); 426 + } 427 + }
+92
crates/ein-tui/src/connection.rs
··· 237 237 Ok(resp.into_inner()) 238 238 } 239 239 240 + // --------------------------------------------------------------------------- 241 + // Tests 242 + // --------------------------------------------------------------------------- 243 + 244 + #[cfg(test)] 245 + mod tests { 246 + use std::collections::HashMap; 247 + 248 + use super::*; 249 + use crate::config::{ClientConfig, PluginConfig}; 250 + 251 + #[test] 252 + fn to_proto_maps_basic_fields() { 253 + let cfg = ClientConfig { 254 + allowed_paths: vec!["/home/user".to_string()], 255 + allowed_hosts: vec!["openrouter.ai".to_string()], 256 + model_client_name: "ein_openrouter".to_string(), 257 + plugin_configs: HashMap::new(), 258 + }; 259 + 260 + let proto = to_proto_session_config(&cfg, "sess-123".to_string()); 261 + 262 + assert_eq!(proto.allowed_paths, vec!["/home/user"]); 263 + assert_eq!(proto.allowed_hosts, vec!["openrouter.ai"]); 264 + assert_eq!(proto.model_client_name, "ein_openrouter"); 265 + assert_eq!(proto.session_id, "sess-123"); 266 + } 267 + 268 + #[test] 269 + fn to_proto_empty_plugin_configs_produces_empty_map() { 270 + let proto = to_proto_session_config(&ClientConfig::default(), "id".to_string()); 271 + assert!(proto.plugin_configs.is_empty()); 272 + } 273 + 274 + #[test] 275 + fn to_proto_serializes_plugin_params_as_json() { 276 + let mut params = HashMap::new(); 277 + params.insert("api_key".to_string(), serde_json::json!("sk-test")); 278 + params.insert("model".to_string(), serde_json::json!("my-model")); 279 + 280 + let mut plugin_configs = HashMap::new(); 281 + plugin_configs.insert( 282 + "ein_openrouter".to_string(), 283 + PluginConfig { params, ..Default::default() }, 284 + ); 285 + 286 + let cfg = ClientConfig { plugin_configs, ..Default::default() }; 287 + let proto = to_proto_session_config(&cfg, "id".to_string()); 288 + 289 + let pc = &proto.plugin_configs["ein_openrouter"]; 290 + let parsed: serde_json::Value = serde_json::from_str(&pc.params_json).unwrap(); 291 + assert_eq!(parsed["api_key"].as_str().unwrap(), "sk-test"); 292 + assert_eq!(parsed["model"].as_str().unwrap(), "my-model"); 293 + } 294 + 295 + #[test] 296 + fn to_proto_maps_plugin_allowed_paths_and_hosts() { 297 + let mut plugin_configs = HashMap::new(); 298 + plugin_configs.insert( 299 + "Bash".to_string(), 300 + PluginConfig { 301 + allowed_paths: vec!["/tmp".to_string()], 302 + allowed_hosts: vec!["example.com".to_string()], 303 + params: HashMap::new(), 304 + }, 305 + ); 306 + 307 + let cfg = ClientConfig { plugin_configs, ..Default::default() }; 308 + let proto = to_proto_session_config(&cfg, "id".to_string()); 309 + 310 + let pc = &proto.plugin_configs["Bash"]; 311 + assert_eq!(pc.allowed_paths, vec!["/tmp"]); 312 + assert_eq!(pc.allowed_hosts, vec!["example.com"]); 313 + } 314 + 315 + #[test] 316 + fn to_proto_multiple_plugin_configs_all_present() { 317 + let mut plugin_configs = HashMap::new(); 318 + plugin_configs.insert("ein_openrouter".to_string(), PluginConfig::default()); 319 + plugin_configs.insert("Bash".to_string(), PluginConfig::default()); 320 + plugin_configs.insert("Read".to_string(), PluginConfig::default()); 321 + 322 + let cfg = ClientConfig { plugin_configs, ..Default::default() }; 323 + let proto = to_proto_session_config(&cfg, "id".to_string()); 324 + 325 + assert_eq!(proto.plugin_configs.len(), 3); 326 + assert!(proto.plugin_configs.contains_key("ein_openrouter")); 327 + assert!(proto.plugin_configs.contains_key("Bash")); 328 + assert!(proto.plugin_configs.contains_key("Read")); 329 + } 330 + } 331 + 240 332 /// Opens a short-lived connection and deletes a session by ID. 241 333 /// 242 334 /// Returns `Ok(())` on success; errors are logged by the caller.
+219
crates/ein-tui/src/input.rs
··· 846 846 assert_eq!(app.cumulative_tokens, 42); 847 847 } 848 848 } 849 + 850 + #[cfg(test)] 851 + mod key_events { 852 + use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; 853 + 854 + use crate::app::{DisplayMessage, test_helpers::make_app}; 855 + 856 + use super::*; 857 + 858 + fn key(code: KeyCode) -> KeyEvent { 859 + KeyEvent::new(code, KeyModifiers::NONE) 860 + } 861 + 862 + fn ctrl(code: KeyCode) -> KeyEvent { 863 + KeyEvent::new(code, KeyModifiers::CONTROL) 864 + } 865 + 866 + // --------------------------------------------------------------------------- 867 + // Ctrl-C and slash commands 868 + // --------------------------------------------------------------------------- 869 + 870 + #[tokio::test] 871 + async fn ctrl_c_always_quits() { 872 + let mut app = make_app("m"); 873 + let action = handle_key_event(&mut app, ctrl(KeyCode::Char('c'))).await; 874 + assert!(matches!(action, KeyAction::Quit)); 875 + } 876 + 877 + #[tokio::test] 878 + async fn exit_command_quits() { 879 + let mut app = make_app("m"); 880 + app.input = "/exit".to_string(); 881 + let action = handle_key_event(&mut app, key(KeyCode::Enter)).await; 882 + assert!(matches!(action, KeyAction::Quit)); 883 + } 884 + 885 + #[tokio::test] 886 + async fn new_command_returns_new_session() { 887 + let mut app = make_app("m"); 888 + app.input = "/new".to_string(); 889 + let action = handle_key_event(&mut app, key(KeyCode::Enter)).await; 890 + assert!(matches!(action, KeyAction::NewSession)); 891 + } 892 + 893 + #[tokio::test] 894 + async fn sessions_command_opens_picker() { 895 + let mut app = make_app("m"); 896 + app.input = "/sessions".to_string(); 897 + let action = handle_key_event(&mut app, key(KeyCode::Enter)).await; 898 + assert!(matches!(action, KeyAction::OpenSessionPicker)); 899 + } 900 + 901 + #[tokio::test] 902 + async fn plugins_command_opens_plugin_modal() { 903 + let mut app = make_app("m"); 904 + app.input = "/plugins".to_string(); 905 + let action = handle_key_event(&mut app, key(KeyCode::Enter)).await; 906 + assert!(matches!(action, KeyAction::OpenPluginModal)); 907 + } 908 + 909 + #[tokio::test] 910 + async fn unknown_slash_command_shows_error_message() { 911 + let mut app = make_app("m"); 912 + app.input = "/doesnotexist".to_string(); 913 + let action = handle_key_event(&mut app, key(KeyCode::Enter)).await; 914 + assert!(matches!(action, KeyAction::Continue)); 915 + assert!( 916 + app.messages.iter().any(|m| matches!(m, DisplayMessage::Error(_))), 917 + "unknown slash command must add an Error message" 918 + ); 919 + } 920 + 921 + #[tokio::test] 922 + async fn enter_with_empty_input_is_a_no_op() { 923 + let mut app = make_app("m"); 924 + let initial_msg_count = app.messages.len(); 925 + let action = handle_key_event(&mut app, key(KeyCode::Enter)).await; 926 + assert!(matches!(action, KeyAction::Continue)); 927 + assert_eq!(app.messages.len(), initial_msg_count); 928 + } 929 + 930 + // --------------------------------------------------------------------------- 931 + // Text input editing 932 + // --------------------------------------------------------------------------- 933 + 934 + #[tokio::test] 935 + async fn char_key_appends_to_input_and_advances_cursor() { 936 + let mut app = make_app("m"); 937 + let _ = handle_key_event(&mut app, key(KeyCode::Char('h'))).await; 938 + let _ = handle_key_event(&mut app, key(KeyCode::Char('i'))).await; 939 + assert_eq!(app.input, "hi"); 940 + assert_eq!(app.cursor_pos, 2); 941 + } 942 + 943 + #[tokio::test] 944 + async fn backspace_removes_last_char() { 945 + let mut app = make_app("m"); 946 + app.input = "hello".to_string(); 947 + app.cursor_pos = 5; 948 + let _ = handle_key_event(&mut app, key(KeyCode::Backspace)).await; 949 + assert_eq!(app.input, "hell"); 950 + assert_eq!(app.cursor_pos, 4); 951 + } 952 + 953 + #[tokio::test] 954 + async fn backspace_at_start_of_input_is_a_no_op() { 955 + let mut app = make_app("m"); 956 + app.input = "hi".to_string(); 957 + app.cursor_pos = 0; 958 + let _ = handle_key_event(&mut app, key(KeyCode::Backspace)).await; 959 + assert_eq!(app.input, "hi"); 960 + assert_eq!(app.cursor_pos, 0); 961 + } 962 + 963 + #[tokio::test] 964 + async fn enter_clears_input_after_command() { 965 + let mut app = make_app("m"); 966 + app.input = "/exit".to_string(); 967 + let _ = handle_key_event(&mut app, key(KeyCode::Enter)).await; 968 + // input was consumed (taken) by the Enter handler 969 + assert!(app.input.is_empty()); 970 + } 971 + 972 + // --------------------------------------------------------------------------- 973 + // Scrolling 974 + // --------------------------------------------------------------------------- 975 + 976 + #[tokio::test] 977 + async fn up_arrow_increments_scroll_offset_and_disables_autoscroll() { 978 + let mut app = make_app("m"); 979 + app.auto_scroll = true; 980 + app.scroll_offset = 0; 981 + let _ = handle_key_event(&mut app, key(KeyCode::Up)).await; 982 + assert_eq!(app.scroll_offset, 1); 983 + assert!(!app.auto_scroll); 984 + } 985 + 986 + #[tokio::test] 987 + async fn down_arrow_at_bottom_re_enables_autoscroll() { 988 + let mut app = make_app("m"); 989 + app.scroll_offset = 0; 990 + app.auto_scroll = false; 991 + let _ = handle_key_event(&mut app, key(KeyCode::Down)).await; 992 + assert!(app.auto_scroll); 993 + } 994 + 995 + #[tokio::test] 996 + async fn down_arrow_above_bottom_decrements_scroll_offset() { 997 + let mut app = make_app("m"); 998 + app.scroll_offset = 5; 999 + let _ = handle_key_event(&mut app, key(KeyCode::Down)).await; 1000 + assert_eq!(app.scroll_offset, 4); 1001 + } 1002 + 1003 + // --------------------------------------------------------------------------- 1004 + // update_autocomplete 1005 + // --------------------------------------------------------------------------- 1006 + 1007 + #[test] 1008 + fn autocomplete_activates_for_slash_prefix() { 1009 + let mut app = make_app("m"); 1010 + app.input = "/ex".to_string(); 1011 + update_autocomplete(&mut app); 1012 + assert!(app.autocomplete_active); 1013 + assert!(!app.autocomplete_matches.is_empty()); 1014 + } 1015 + 1016 + #[test] 1017 + fn autocomplete_slash_alone_matches_all_commands() { 1018 + let mut app = make_app("m"); 1019 + app.input = "/".to_string(); 1020 + update_autocomplete(&mut app); 1021 + assert!(app.autocomplete_active); 1022 + assert_eq!(app.autocomplete_matches.len(), COMMANDS.len()); 1023 + } 1024 + 1025 + #[test] 1026 + fn autocomplete_deactivates_for_non_slash_input() { 1027 + let mut app = make_app("m"); 1028 + app.input = "hello".to_string(); 1029 + app.autocomplete_active = true; 1030 + app.autocomplete_matches = vec![0]; 1031 + update_autocomplete(&mut app); 1032 + assert!(!app.autocomplete_active); 1033 + assert!(app.autocomplete_matches.is_empty()); 1034 + } 1035 + 1036 + #[test] 1037 + fn autocomplete_no_match_leaves_inactive() { 1038 + let mut app = make_app("m"); 1039 + app.input = "/zzz".to_string(); 1040 + update_autocomplete(&mut app); 1041 + assert!(!app.autocomplete_active); 1042 + } 1043 + 1044 + // --------------------------------------------------------------------------- 1045 + // char_to_byte_idx (private helper, accessible from same-file test module) 1046 + // --------------------------------------------------------------------------- 1047 + 1048 + #[test] 1049 + fn char_to_byte_idx_ascii() { 1050 + assert_eq!(char_to_byte_idx("hello", 0), 0); 1051 + assert_eq!(char_to_byte_idx("hello", 3), 3); 1052 + assert_eq!(char_to_byte_idx("hello", 5), 5); 1053 + } 1054 + 1055 + #[test] 1056 + fn char_to_byte_idx_multibyte() { 1057 + let s = "héllo"; // é is 2 bytes (U+00E9) 1058 + assert_eq!(char_to_byte_idx(s, 0), 0); 1059 + assert_eq!(char_to_byte_idx(s, 1), 1); // 'h' 1060 + assert_eq!(char_to_byte_idx(s, 2), 3); // after 'é' (2 bytes) 1061 + } 1062 + 1063 + #[test] 1064 + fn char_to_byte_idx_past_end_returns_len() { 1065 + assert_eq!(char_to_byte_idx("hi", 99), 2); 1066 + } 1067 + }
+3
packages/ein_edit/Cargo.toml
··· 17 17 serde = { workspace = true } 18 18 serde_json = { workspace = true } 19 19 wit-bindgen = { workspace = true } 20 + 21 + [dev-dependencies] 22 + tempfile = "3"
+122
packages/ein_edit/src/lib.rs
··· 78 78 79 79 #[cfg(target_arch = "wasm32")] 80 80 ein_plugin::export_tool!(EditTool); 81 + 82 + #[cfg(test)] 83 + mod tests { 84 + use super::*; 85 + use std::io::Write; 86 + use tempfile::NamedTempFile; 87 + 88 + fn tool() -> EditTool { 89 + EditTool::new() 90 + } 91 + 92 + fn call(path: &str, old: &str, new: &str) -> anyhow::Result<ToolResult> { 93 + let args = serde_json::json!({ 94 + "file_path": path, 95 + "old_string": old, 96 + "new_string": new, 97 + }) 98 + .to_string(); 99 + tool().call("id", &args) 100 + } 101 + 102 + fn write_temp(content: &str) -> NamedTempFile { 103 + let mut f = NamedTempFile::new().unwrap(); 104 + f.write_all(content.as_bytes()).unwrap(); 105 + f 106 + } 107 + 108 + // --------------------------------------------------------------------------- 109 + // Replacement behaviour 110 + // --------------------------------------------------------------------------- 111 + 112 + #[test] 113 + fn edit_replaces_first_occurrence_only() { 114 + let f = write_temp("hello world world"); 115 + let result = call(f.path().to_str().unwrap(), "world", "Rust").unwrap(); 116 + let on_disk = fs::read_to_string(f.path()).unwrap(); 117 + assert_eq!(on_disk, "hello Rust world"); 118 + assert!(result.content.contains("Successfully edited")); 119 + } 120 + 121 + #[test] 122 + fn edit_empty_new_string_deletes_matched_text() { 123 + let f = write_temp("remove_me keep"); 124 + call(f.path().to_str().unwrap(), "remove_me ", "").unwrap(); 125 + assert_eq!(fs::read_to_string(f.path()).unwrap(), "keep"); 126 + } 127 + 128 + #[test] 129 + fn edit_multiline_old_string_replaced_correctly() { 130 + let f = write_temp("line1\nline2\nline3\n"); 131 + call(f.path().to_str().unwrap(), "line1\nline2", "replaced").unwrap(); 132 + assert_eq!(fs::read_to_string(f.path()).unwrap(), "replaced\nline3\n"); 133 + } 134 + 135 + // --------------------------------------------------------------------------- 136 + // Metadata: start_line 137 + // --------------------------------------------------------------------------- 138 + 139 + #[test] 140 + fn edit_start_line_is_1_for_match_at_top() { 141 + let f = write_temp("target here\nother line\n"); 142 + let result = call(f.path().to_str().unwrap(), "target", "X").unwrap(); 143 + let meta = result.metadata.unwrap(); 144 + assert_eq!(meta["start_line"], 1); 145 + } 146 + 147 + #[test] 148 + fn edit_start_line_accounts_for_preceding_newlines() { 149 + let f = write_temp("line1\nline2\ntarget\nline4\n"); 150 + let result = call(f.path().to_str().unwrap(), "target", "X").unwrap(); 151 + let meta = result.metadata.unwrap(); 152 + assert_eq!(meta["start_line"], 3); 153 + } 154 + 155 + #[test] 156 + fn edit_multiline_start_line_is_line_of_first_character() { 157 + let f = write_temp("a\nb\nc\nd\n"); 158 + // old_string starts at line 2 159 + let result = call(f.path().to_str().unwrap(), "b\nc", "X").unwrap(); 160 + let meta = result.metadata.unwrap(); 161 + assert_eq!(meta["start_line"], 2); 162 + } 163 + 164 + // --------------------------------------------------------------------------- 165 + // Metadata: old_lines / new_lines 166 + // --------------------------------------------------------------------------- 167 + 168 + #[test] 169 + fn edit_metadata_contains_old_and_new_lines() { 170 + let f = write_temp("foo\nbar\nbaz\n"); 171 + let result = call(f.path().to_str().unwrap(), "bar", "qux").unwrap(); 172 + let meta = result.metadata.unwrap(); 173 + assert_eq!(meta["old_lines"], serde_json::json!(["bar"])); 174 + assert_eq!(meta["new_lines"], serde_json::json!(["qux"])); 175 + } 176 + 177 + #[test] 178 + fn edit_metadata_multiline_old_and_new_lines() { 179 + let f = write_temp("a\nb\nc\n"); 180 + let result = call(f.path().to_str().unwrap(), "a\nb", "x\ny\nz").unwrap(); 181 + let meta = result.metadata.unwrap(); 182 + assert_eq!(meta["old_lines"], serde_json::json!(["a", "b"])); 183 + assert_eq!(meta["new_lines"], serde_json::json!(["x", "y", "z"])); 184 + } 185 + 186 + // --------------------------------------------------------------------------- 187 + // Error cases 188 + // --------------------------------------------------------------------------- 189 + 190 + #[test] 191 + fn edit_returns_error_when_old_string_not_found() { 192 + let f = write_temp("hello world"); 193 + let err = call(f.path().to_str().unwrap(), "no such string", "x").unwrap_err(); 194 + assert!(err.to_string().contains("not found"), "got: {err}"); 195 + } 196 + 197 + #[test] 198 + fn edit_returns_error_for_missing_file() { 199 + let err = call("/nonexistent/path/file.txt", "x", "y").unwrap_err(); 200 + assert!(err.to_string().contains("No such file") || err.to_string().contains("os error")); 201 + } 202 + }
+207 -36
packages/ein_ollama/src/lib.rs
··· 51 51 "http://localhost:11434/v1".to_string() 52 52 } 53 53 54 + fn inject_num_ctx(body: &mut serde_json::Value, num_ctx: Option<u32>) { 55 + if let Some(n) = num_ctx { 56 + body["options"] = serde_json::json!({ "num_ctx": n }); 57 + } 58 + } 59 + 60 + fn map_http_error(status: u16, body: &str, model: &str) -> Option<anyhow::Error> { 61 + match status { 62 + 401 => { 63 + let msg = extract_api_error(body).unwrap_or_else(|| "Unauthorized".to_owned()); 64 + Some(anyhow!( 65 + "{msg}\n\n\ 66 + Most local Ollama instances do not require authentication.\n\ 67 + If your deployment uses a bearer token, set it in \ 68 + ~/.ein/config.json under \ 69 + plugin_configs.ein_ollama.params.api_key" 70 + )) 71 + } 72 + 402 => { 73 + let msg = 74 + extract_api_error(body).unwrap_or_else(|| "Payment required".to_owned()); 75 + Some(anyhow!("{msg}")) 76 + } 77 + 404 => { 78 + let msg = 79 + extract_api_error(body).unwrap_or_else(|| "Model not found".to_owned()); 80 + Some(anyhow!( 81 + "{msg}\n\n\ 82 + The model may not be downloaded yet. Run:\n\ 83 + ollama pull {model}\n\ 84 + To list available models: ollama list" 85 + )) 86 + } 87 + s if !(200..300).contains(&s) => { 88 + let msg = extract_api_error(body).unwrap_or_else(|| format!("HTTP {s}")); 89 + Some(anyhow!("API error: {msg}")) 90 + } 91 + _ => None, 92 + } 93 + } 94 + 54 95 pub struct OllamaPlugin { 55 96 config: OllamaConfig, 56 97 } ··· 77 118 // Inject Ollama-specific options (e.g. num_ctx) alongside the standard 78 119 // fields if configured. 79 120 let mut body = serde_json::to_value(&req)?; 80 - if let Some(num_ctx) = self.config.num_ctx { 81 - eprintln!("[ollama] setting num_ctx={num_ctx}"); 82 - body["options"] = serde_json::json!({ "num_ctx": num_ctx }); 121 + if self.config.num_ctx.is_some() { 122 + eprintln!("[ollama] setting num_ctx={}", self.config.num_ctx.unwrap()); 83 123 } 124 + inject_num_ctx(&mut body, self.config.num_ctx); 84 125 85 126 let mut req_builder = HttpRequest::post(url); 86 127 if let Some(key) = &self.config.api_key { ··· 112 153 } 113 154 })?; 114 155 115 - match resp.status { 116 - 401 => { 117 - let msg = 118 - extract_api_error(&resp.body).unwrap_or_else(|| "Unauthorized".to_owned()); 119 - return Err(anyhow!( 120 - "{msg}\n\n\ 121 - Most local Ollama instances do not require authentication.\n\ 122 - If your deployment uses a bearer token, set it in \ 123 - ~/.ein/config.json under \ 124 - plugin_configs.ein_ollama.params.api_key" 125 - )); 126 - } 127 - 402 => { 128 - let msg = 129 - extract_api_error(&resp.body).unwrap_or_else(|| "Payment required".to_owned()); 130 - return Err(anyhow!("{msg}")); 131 - } 132 - 404 => { 133 - let msg = 134 - extract_api_error(&resp.body).unwrap_or_else(|| "Model not found".to_owned()); 135 - return Err(anyhow!( 136 - "{msg}\n\n\ 137 - The model may not be downloaded yet. Run:\n\ 138 - ollama pull {}\n\ 139 - To list available models: ollama list", 140 - req.model 141 - )); 142 - } 143 - s if !(200..300).contains(&s) => { 144 - let msg = extract_api_error(&resp.body).unwrap_or_else(|| format!("HTTP {s}")); 145 - return Err(anyhow!("API error: {msg}")); 146 - } 147 - _ => {} 156 + if let Some(e) = map_http_error(resp.status, &resp.body, &req.model) { 157 + return Err(e); 148 158 } 149 159 150 160 // Validate the body parses as CompletionResponse before returning. ··· 162 172 163 173 #[cfg(target_arch = "wasm32")] 164 174 ein_plugin::export_model_client!(OllamaPlugin); 175 + 176 + #[cfg(test)] 177 + mod tests { 178 + use super::*; 179 + use ein_plugin::model_client::CompletionResponse; 180 + use serde_json::json; 181 + 182 + // --------------------------------------------------------------------------- 183 + // extract_api_error 184 + // --------------------------------------------------------------------------- 185 + 186 + #[test] 187 + fn extract_api_error_present() { 188 + let body = r#"{"error": {"message": "model not loaded", "type": "not_found"}}"#; 189 + assert_eq!( 190 + extract_api_error(body).as_deref(), 191 + Some("model not loaded") 192 + ); 193 + } 194 + 195 + #[test] 196 + fn extract_api_error_missing_error_key() { 197 + assert!(extract_api_error(r#"{"choices": []}"#).is_none()); 198 + } 199 + 200 + #[test] 201 + fn extract_api_error_malformed_json() { 202 + assert!(extract_api_error("not json").is_none()); 203 + } 204 + 205 + // --------------------------------------------------------------------------- 206 + // OllamaConfig deserialization 207 + // --------------------------------------------------------------------------- 208 + 209 + #[test] 210 + fn config_default_base_url() { 211 + let cfg: OllamaConfig = serde_json::from_value(json!({})).unwrap(); 212 + assert_eq!(cfg.base_url, "http://localhost:11434/v1"); 213 + } 214 + 215 + #[test] 216 + fn config_absent_api_key_is_none() { 217 + let cfg: OllamaConfig = serde_json::from_value(json!({})).unwrap(); 218 + assert!(cfg.api_key.is_none()); 219 + } 220 + 221 + #[test] 222 + fn config_empty_api_key_treated_as_none() { 223 + let cfg: OllamaConfig = serde_json::from_value(json!({"api_key": ""})).unwrap(); 224 + assert!(cfg.api_key.is_none()); 225 + } 226 + 227 + #[test] 228 + fn config_valid_api_key() { 229 + let cfg: OllamaConfig = serde_json::from_value(json!({"api_key": "tok"})).unwrap(); 230 + assert_eq!(cfg.api_key.as_deref(), Some("tok")); 231 + } 232 + 233 + #[test] 234 + fn config_num_ctx_absent_is_none() { 235 + let cfg: OllamaConfig = serde_json::from_value(json!({})).unwrap(); 236 + assert!(cfg.num_ctx.is_none()); 237 + } 238 + 239 + #[test] 240 + fn config_num_ctx_set() { 241 + let cfg: OllamaConfig = serde_json::from_value(json!({"num_ctx": 16384})).unwrap(); 242 + assert_eq!(cfg.num_ctx, Some(16384)); 243 + } 244 + 245 + // --------------------------------------------------------------------------- 246 + // inject_num_ctx 247 + // --------------------------------------------------------------------------- 248 + 249 + #[test] 250 + fn num_ctx_injected_into_body() { 251 + let mut body = json!({"model": "llama3", "messages": []}); 252 + inject_num_ctx(&mut body, Some(8192)); 253 + assert_eq!(body["options"]["num_ctx"], 8192); 254 + } 255 + 256 + #[test] 257 + fn num_ctx_not_injected_when_absent() { 258 + let mut body = json!({"model": "llama3", "messages": []}); 259 + inject_num_ctx(&mut body, None); 260 + assert!(body.get("options").is_none()); 261 + } 262 + 263 + // --------------------------------------------------------------------------- 264 + // map_http_error 265 + // --------------------------------------------------------------------------- 266 + 267 + #[test] 268 + fn map_http_error_401_contains_api_key_hint() { 269 + let err = map_http_error(401, "{}", "llama3").unwrap(); 270 + let msg = err.to_string(); 271 + assert!(msg.contains("api_key"), "expected api_key hint in: {msg}"); 272 + } 273 + 274 + #[test] 275 + fn map_http_error_401_includes_api_message() { 276 + let body = r#"{"error": {"message": "Invalid token"}}"#; 277 + let err = map_http_error(401, body, "llama3").unwrap(); 278 + assert!(err.to_string().contains("Invalid token")); 279 + } 280 + 281 + #[test] 282 + fn map_http_error_404_suggests_ollama_pull() { 283 + let err = map_http_error(404, "{}", "mistral").unwrap(); 284 + let msg = err.to_string(); 285 + assert!(msg.contains("ollama pull"), "expected 'ollama pull' in: {msg}"); 286 + assert!(msg.contains("mistral"), "expected model name in: {msg}"); 287 + } 288 + 289 + #[test] 290 + fn map_http_error_404_passes_through_api_message() { 291 + let body = r#"{"error": {"message": "model 'qwen' not found"}}"#; 292 + let err = map_http_error(404, body, "qwen").unwrap(); 293 + assert!(err.to_string().contains("model 'qwen' not found")); 294 + } 295 + 296 + #[test] 297 + fn map_http_error_other_non_2xx() { 298 + let err = map_http_error(503, "{}", "llama3").unwrap(); 299 + let msg = err.to_string(); 300 + assert!(msg.contains("503"), "expected status code in: {msg}"); 301 + } 302 + 303 + #[test] 304 + fn map_http_error_2xx_returns_none() { 305 + assert!(map_http_error(200, "{}", "llama3").is_none()); 306 + assert!(map_http_error(201, "{}", "llama3").is_none()); 307 + } 308 + 309 + // --------------------------------------------------------------------------- 310 + // Response body validation 311 + // --------------------------------------------------------------------------- 312 + 313 + #[test] 314 + fn valid_completion_response_parses() { 315 + let body = r#"{ 316 + "id": "ollama-gen-1", 317 + "object": "chat.completion", 318 + "model": "llama3", 319 + "choices": [{ 320 + "index": 0, 321 + "finish_reason": "stop", 322 + "message": {"role": "assistant", "content": "Hello!"} 323 + }], 324 + "usage": {"prompt_tokens": 8, "completion_tokens": 3, "total_tokens": 11} 325 + }"#; 326 + let resp: Result<CompletionResponse, _> = serde_json::from_str(body); 327 + assert!(resp.is_ok(), "expected valid response to parse: {:?}", resp); 328 + } 329 + 330 + #[test] 331 + fn invalid_response_body_returns_error() { 332 + let resp: Result<CompletionResponse, _> = serde_json::from_str("not valid json"); 333 + assert!(resp.is_err()); 334 + } 335 + }
+168 -29
packages/ein_openrouter/src/lib.rs
··· 26 26 .and_then(|v| v.get("error")?.get("message")?.as_str().map(str::to_owned)) 27 27 } 28 28 29 + fn map_http_error(status: u16, body: &str) -> Option<anyhow::Error> { 30 + match status { 31 + 401 => { 32 + let msg = 33 + extract_api_error(body).unwrap_or_else(|| "Invalid or missing API key".to_owned()); 34 + Some(anyhow!( 35 + "{msg}\n\n\ 36 + Set your api_key in ~/.ein/config.json under \ 37 + plugin_configs.ein_openrouter.params.api_key" 38 + )) 39 + } 40 + 402 => { 41 + let msg = 42 + extract_api_error(body).unwrap_or_else(|| "Insufficient credits".to_owned()); 43 + Some(anyhow!( 44 + "{msg}\n\nCheck your account balance at openrouter.ai." 45 + )) 46 + } 47 + 404 => { 48 + let msg = 49 + extract_api_error(body).unwrap_or_else(|| "Resource not found".to_owned()); 50 + Some(anyhow!("{msg}")) 51 + } 52 + s if !(200..300).contains(&s) => { 53 + let msg = extract_api_error(body).unwrap_or_default(); 54 + Some(anyhow!("Status HTTP {s}. API error: {msg}")) 55 + } 56 + _ => None, 57 + } 58 + } 59 + 29 60 #[derive(Deserialize)] 30 61 struct OpenRouterConfig { 31 62 api_key: String, ··· 65 96 .send() 66 97 .map_err(|e| anyhow!("Could not connect to {}: {e}", self.config.base_url))?; 67 98 68 - match resp.status { 69 - 401 => { 70 - let msg = extract_api_error(&resp.body) 71 - .unwrap_or_else(|| "Invalid or missing API key".to_owned()); 72 - return Err(anyhow!( 73 - "{msg}\n\n\ 74 - Set your api_key in ~/.ein/config.json under \ 75 - plugin_configs.ein_openrouter.params.api_key" 76 - )); 77 - } 78 - 402 => { 79 - let msg = extract_api_error(&resp.body) 80 - .unwrap_or_else(|| "Insufficient credits".to_owned()); 81 - return Err(anyhow!( 82 - "{msg}\n\nCheck your account balance at openrouter.ai." 83 - )); 84 - } 85 - 404 => { 86 - let msg = extract_api_error(&resp.body) 87 - .unwrap_or_else(|| "Resource not found".to_owned()); 88 - return Err(anyhow!("{msg}")); 89 - } 90 - s if !(200..300).contains(&s) => { 91 - let status = format!("HTTP {s}"); 92 - let msg = extract_api_error(&resp.body).unwrap_or_default(); 93 - 94 - return Err(anyhow!("Status {status}. API error: {msg}")); 95 - } 96 - _ => {} 99 + if let Some(e) = map_http_error(resp.status, &resp.body) { 100 + return Err(e); 97 101 } 98 102 99 103 // Validate the body parses as CompletionResponse before returning. ··· 111 115 112 116 #[cfg(target_arch = "wasm32")] 113 117 ein_plugin::export_model_client!(OpenRouterPlugin); 118 + 119 + #[cfg(test)] 120 + mod tests { 121 + use super::*; 122 + use ein_plugin::model_client::CompletionResponse; 123 + use serde_json::json; 124 + 125 + // --------------------------------------------------------------------------- 126 + // extract_api_error 127 + // --------------------------------------------------------------------------- 128 + 129 + #[test] 130 + fn extract_api_error_present() { 131 + let body = r#"{"error": {"message": "You exceeded your quota", "type": "quota_exceeded"}}"#; 132 + assert_eq!( 133 + extract_api_error(body).as_deref(), 134 + Some("You exceeded your quota") 135 + ); 136 + } 137 + 138 + #[test] 139 + fn extract_api_error_missing_error_key() { 140 + assert!(extract_api_error(r#"{"choices": []}"#).is_none()); 141 + } 142 + 143 + #[test] 144 + fn extract_api_error_missing_message_key() { 145 + assert!(extract_api_error(r#"{"error": {"type": "server_error"}}"#).is_none()); 146 + } 147 + 148 + #[test] 149 + fn extract_api_error_malformed_json() { 150 + assert!(extract_api_error("not json at all").is_none()); 151 + } 152 + 153 + // --------------------------------------------------------------------------- 154 + // OpenRouterConfig deserialization 155 + // --------------------------------------------------------------------------- 156 + 157 + #[test] 158 + fn config_default_base_url() { 159 + let cfg: OpenRouterConfig = 160 + serde_json::from_value(json!({"api_key": "sk-or-test"})).unwrap(); 161 + assert_eq!(cfg.api_key, "sk-or-test"); 162 + assert_eq!(cfg.base_url, "https://openrouter.ai/api/v1"); 163 + } 164 + 165 + #[test] 166 + fn config_custom_base_url() { 167 + let cfg: OpenRouterConfig = serde_json::from_value(json!({ 168 + "api_key": "sk-or-test", 169 + "base_url": "https://my-proxy.example.com/v1" 170 + })) 171 + .unwrap(); 172 + assert_eq!(cfg.base_url, "https://my-proxy.example.com/v1"); 173 + } 174 + 175 + // --------------------------------------------------------------------------- 176 + // map_http_error 177 + // --------------------------------------------------------------------------- 178 + 179 + #[test] 180 + fn map_http_error_401_contains_api_key_hint() { 181 + let err = map_http_error(401, "{}").unwrap(); 182 + let msg = err.to_string(); 183 + assert!(msg.contains("api_key"), "expected api_key hint in: {msg}"); 184 + } 185 + 186 + #[test] 187 + fn map_http_error_401_includes_api_message() { 188 + let body = r#"{"error": {"message": "Incorrect API key provided"}}"#; 189 + let err = map_http_error(401, body).unwrap(); 190 + assert!(err.to_string().contains("Incorrect API key provided")); 191 + } 192 + 193 + #[test] 194 + fn map_http_error_402_mentions_credits_and_balance() { 195 + let err = map_http_error(402, "{}").unwrap(); 196 + let msg = err.to_string(); 197 + assert!(msg.contains("openrouter.ai"), "expected openrouter.ai link in: {msg}"); 198 + } 199 + 200 + #[test] 201 + fn map_http_error_404_passes_through_api_message() { 202 + let body = r#"{"error": {"message": "No endpoints found for model"}}"#; 203 + let err = map_http_error(404, body).unwrap(); 204 + assert!(err.to_string().contains("No endpoints found for model")); 205 + } 206 + 207 + #[test] 208 + fn map_http_error_404_fallback_when_no_api_message() { 209 + let err = map_http_error(404, "{}").unwrap(); 210 + assert!(err.to_string().contains("Resource not found")); 211 + } 212 + 213 + #[test] 214 + fn map_http_error_other_non_2xx() { 215 + let err = map_http_error(503, "{}").unwrap(); 216 + let msg = err.to_string(); 217 + assert!(msg.contains("503"), "expected status code in: {msg}"); 218 + } 219 + 220 + #[test] 221 + fn map_http_error_2xx_returns_none() { 222 + assert!(map_http_error(200, "{}").is_none()); 223 + assert!(map_http_error(201, "{}").is_none()); 224 + } 225 + 226 + // --------------------------------------------------------------------------- 227 + // Response body validation 228 + // --------------------------------------------------------------------------- 229 + 230 + #[test] 231 + fn valid_completion_response_parses() { 232 + let body = r#"{ 233 + "id": "gen-abc", 234 + "object": "chat.completion", 235 + "model": "anthropic/claude-haiku-4-5", 236 + "choices": [{ 237 + "index": 0, 238 + "finish_reason": "stop", 239 + "message": {"role": "assistant", "content": "Hello!"} 240 + }], 241 + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} 242 + }"#; 243 + let resp: Result<CompletionResponse, _> = serde_json::from_str(body); 244 + assert!(resp.is_ok(), "expected valid response to parse: {:?}", resp); 245 + } 246 + 247 + #[test] 248 + fn invalid_response_body_returns_error() { 249 + let resp: Result<CompletionResponse, _> = serde_json::from_str("not valid json"); 250 + assert!(resp.is_err()); 251 + } 252 + }
+3
packages/ein_read/Cargo.toml
··· 17 17 serde = { workspace = true } 18 18 serde_json = { workspace = true } 19 19 wit-bindgen = { workspace = true} 20 + 21 + [dev-dependencies] 22 + tempfile = "3"
+125
packages/ein_read/src/lib.rs
··· 62 62 } 63 63 64 64 fn call(&self, id: &str, args: &str) -> anyhow::Result<ToolResult> { 65 + #[cfg(target_arch = "wasm32")] 65 66 ein_plugin::tool::syscalls::log(&format!("Reading file with args: {args}")); 66 67 67 68 let args: ReadArgs = serde_json::from_str(args)?; ··· 94 95 95 96 #[cfg(target_arch = "wasm32")] 96 97 ein_plugin::export_tool!(ReadTool); 98 + 99 + #[cfg(test)] 100 + mod tests { 101 + use super::*; 102 + use std::io::Write; 103 + use tempfile::NamedTempFile; 104 + 105 + fn tool() -> ReadTool { 106 + ReadTool::new() 107 + } 108 + 109 + fn call(path: &str, offset: Option<usize>, limit: Option<usize>) -> anyhow::Result<String> { 110 + let mut args = serde_json::json!({ "file_path": path }); 111 + if let Some(o) = offset { 112 + args["offset"] = serde_json::json!(o); 113 + } 114 + if let Some(l) = limit { 115 + args["limit"] = serde_json::json!(l); 116 + } 117 + tool().call("id", &args.to_string()).map(|r| r.content) 118 + } 119 + 120 + fn write_temp(content: &str) -> NamedTempFile { 121 + let mut f = NamedTempFile::new().unwrap(); 122 + f.write_all(content.as_bytes()).unwrap(); 123 + f 124 + } 125 + 126 + // --------------------------------------------------------------------------- 127 + // Basic reading 128 + // --------------------------------------------------------------------------- 129 + 130 + #[test] 131 + fn read_returns_all_lines_by_default() { 132 + let f = write_temp("a\nb\nc\nd\ne\n"); 133 + let out = call(f.path().to_str().unwrap(), None, None).unwrap(); 134 + assert_eq!(out, "a\nb\nc\nd\ne"); 135 + } 136 + 137 + #[test] 138 + fn read_empty_file_returns_empty_string() { 139 + let f = write_temp(""); 140 + let out = call(f.path().to_str().unwrap(), None, None).unwrap(); 141 + assert_eq!(out, ""); 142 + } 143 + 144 + // --------------------------------------------------------------------------- 145 + // Offset / limit windowing 146 + // --------------------------------------------------------------------------- 147 + 148 + #[test] 149 + fn read_with_offset_skips_lines() { 150 + let f = write_temp("line1\nline2\nline3\nline4\n"); 151 + let out = call(f.path().to_str().unwrap(), Some(2), None).unwrap(); 152 + assert_eq!(out, "line3\nline4"); 153 + } 154 + 155 + #[test] 156 + fn read_with_limit_caps_output() { 157 + let f = write_temp("a\nb\nc\nd\ne\n"); 158 + let out = call(f.path().to_str().unwrap(), None, Some(3)).unwrap(); 159 + // truncated — must start with the header 160 + assert!(out.starts_with("Lines 1-3 of 5"), "got: {out}"); 161 + assert!(out.contains("a\nb\nc")); 162 + } 163 + 164 + #[test] 165 + fn read_offset_and_limit_combined() { 166 + let f = write_temp("a\nb\nc\nd\ne\n"); 167 + let out = call(f.path().to_str().unwrap(), Some(1), Some(2)).unwrap(); 168 + // offset=1, limit=2 → lines b, c; 2 more remain → truncation header 169 + assert!(out.contains("b\nc"), "got: {out}"); 170 + } 171 + 172 + #[test] 173 + fn read_offset_beyond_file_end_returns_empty() { 174 + let f = write_temp("only one line\n"); 175 + let out = call(f.path().to_str().unwrap(), Some(99), None).unwrap(); 176 + assert_eq!(out, ""); 177 + } 178 + 179 + #[test] 180 + fn read_limit_equal_to_line_count_produces_no_header() { 181 + let f = write_temp("x\ny\nz\n"); 182 + let out = call(f.path().to_str().unwrap(), None, Some(3)).unwrap(); 183 + // exactly fits — no truncation header 184 + assert!(!out.contains("Lines"), "got: {out}"); 185 + assert_eq!(out, "x\ny\nz"); 186 + } 187 + 188 + // --------------------------------------------------------------------------- 189 + // Truncation header format 190 + // --------------------------------------------------------------------------- 191 + 192 + #[test] 193 + fn read_truncation_header_shows_range_and_total() { 194 + // 10 lines, read only first 4 195 + let content = (1..=10).map(|i| format!("line{i}")).collect::<Vec<_>>().join("\n"); 196 + let f = write_temp(&content); 197 + let out = call(f.path().to_str().unwrap(), None, Some(4)).unwrap(); 198 + assert!(out.starts_with("Lines 1-4 of 10"), "got: {out}"); 199 + assert!(out.contains("use offset=4 to read more"), "got: {out}"); 200 + } 201 + 202 + #[test] 203 + fn read_truncation_header_reflects_offset() { 204 + let content = (1..=10).map(|i| format!("line{i}")).collect::<Vec<_>>().join("\n"); 205 + let f = write_temp(&content); 206 + let out = call(f.path().to_str().unwrap(), Some(3), Some(3)).unwrap(); 207 + // offset=3, limit=3 → lines 4-6 (1-based), 4 remain → header 208 + assert!(out.starts_with("Lines 4-6 of 10"), "got: {out}"); 209 + assert!(out.contains("use offset=6 to read more"), "got: {out}"); 210 + } 211 + 212 + // --------------------------------------------------------------------------- 213 + // Error cases 214 + // --------------------------------------------------------------------------- 215 + 216 + #[test] 217 + fn read_returns_error_for_missing_file() { 218 + let err = call("/nonexistent/path/file.txt", None, None).unwrap_err(); 219 + assert!(err.to_string().contains("No such file") || err.to_string().contains("os error")); 220 + } 221 + }
+3
packages/ein_write/Cargo.toml
··· 17 17 serde = { workspace = true } 18 18 serde_json = { workspace = true } 19 19 wit-bindgen = { workspace = true} 20 + 21 + [dev-dependencies] 22 + tempfile = "3"
+57
packages/ein_write/src/lib.rs
··· 81 81 82 82 #[cfg(target_arch = "wasm32")] 83 83 ein_plugin::export_tool!(WriteTool); 84 + 85 + #[cfg(test)] 86 + mod tests { 87 + use super::*; 88 + use tempfile::TempDir; 89 + 90 + fn tool() -> WriteTool { 91 + WriteTool::new() 92 + } 93 + 94 + fn call(path: &str, content: &str) -> anyhow::Result<ToolResult> { 95 + let args = serde_json::json!({ "file_path": path, "content": content }).to_string(); 96 + tool().call("id", &args) 97 + } 98 + 99 + #[test] 100 + fn write_creates_file_with_content() { 101 + let dir = TempDir::new().unwrap(); 102 + let path = dir.path().join("out.txt"); 103 + call(path.to_str().unwrap(), "hello world").unwrap(); 104 + assert_eq!(fs::read_to_string(&path).unwrap(), "hello world"); 105 + } 106 + 107 + #[test] 108 + fn write_overwrites_existing_file() { 109 + let dir = TempDir::new().unwrap(); 110 + let path = dir.path().join("out.txt"); 111 + call(path.to_str().unwrap(), "first").unwrap(); 112 + call(path.to_str().unwrap(), "second").unwrap(); 113 + assert_eq!(fs::read_to_string(&path).unwrap(), "second"); 114 + } 115 + 116 + #[test] 117 + fn write_creates_parent_directories() { 118 + let dir = TempDir::new().unwrap(); 119 + let path = dir.path().join("a").join("b").join("c").join("file.txt"); 120 + call(path.to_str().unwrap(), "nested").unwrap(); 121 + assert_eq!(fs::read_to_string(&path).unwrap(), "nested"); 122 + } 123 + 124 + #[test] 125 + fn write_empty_content_creates_empty_file() { 126 + let dir = TempDir::new().unwrap(); 127 + let path = dir.path().join("empty.txt"); 128 + call(path.to_str().unwrap(), "").unwrap(); 129 + assert_eq!(fs::read_to_string(&path).unwrap(), ""); 130 + } 131 + 132 + #[test] 133 + fn write_returns_success_message_with_byte_count_and_path() { 134 + let dir = TempDir::new().unwrap(); 135 + let path = dir.path().join("f.txt"); 136 + let result = call(path.to_str().unwrap(), "abc").unwrap(); 137 + assert!(result.content.contains('3'), "expected byte count in: {}", result.content); 138 + assert!(result.content.contains(path.to_str().unwrap()), "expected path in: {}", result.content); 139 + } 140 + }