A lexicon-driven AppView for ATProto. happyview.dev
backfill firehose jetstream atproto appview oauth lexicon
8
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat: add xrpc libs for Lua

Trezy a65cd4b2 7c5def86

+731 -10
+8 -3
src/auth/middleware.rs
··· 34 34 self.client_key.as_deref() 35 35 } 36 36 37 - /// Test-only constructor. 38 - #[cfg(test)] 39 - pub fn new_for_test(did: String) -> Self { 37 + /// Create claims for an internal call (e.g. Lua xrpc lib) with no client key. 38 + pub fn internal(did: String) -> Self { 40 39 Self { 41 40 did, 42 41 client_key: None, 43 42 } 43 + } 44 + 45 + /// Test-only constructor. 46 + #[cfg(test)] 47 + pub fn new_for_test(did: String) -> Self { 48 + Self::internal(did) 44 49 } 45 50 } 46 51
+55
src/lua/execute.rs
··· 163 163 return Err(AppError::Internal(error_message)); 164 164 } 165 165 166 + if let Err(e) = 167 + super::xrpc_api::register_xrpc_api(&lua, state_arc.clone(), Some(claims.did().to_string())) 168 + { 169 + let error_message = format!("failed to register xrpc API: {e}"); 170 + log_event( 171 + &state.db, 172 + EventLog { 173 + event_type: "script.error".to_string(), 174 + severity: Severity::Error, 175 + actor_did: Some(claims.did().to_string()), 176 + subject: Some(method.to_string()), 177 + detail: serde_json::json!({ 178 + "error": error_message, 179 + "script_source": script_source, 180 + "input": input_json, 181 + "caller_did": claims.did(), 182 + "method": method, 183 + "duration_ms": start.elapsed().as_millis() as u64, 184 + }), 185 + }, 186 + backend, 187 + ) 188 + .await; 189 + return Err(AppError::Internal(error_message)); 190 + } 191 + 166 192 if let Err(e) = atproto_api::register_atproto_api(&lua, state_arc.clone(), Some(claims.did())) { 167 193 let error_message = format!("failed to register atproto API: {e}"); 168 194 log_event( ··· 520 546 return Err(AppError::Internal(error_message)); 521 547 } 522 548 549 + if let Err(e) = super::xrpc_api::register_xrpc_api( 550 + &lua, 551 + state_arc.clone(), 552 + claims.map(|c| c.did().to_string()), 553 + ) { 554 + let error_message = format!("failed to register xrpc API: {e}"); 555 + log_event( 556 + &state.db, 557 + EventLog { 558 + event_type: "script.error".to_string(), 559 + severity: Severity::Error, 560 + actor_did: None, 561 + subject: Some(method.to_string()), 562 + detail: serde_json::json!({ 563 + "error": error_message, 564 + "script_source": script_source, 565 + "method": method, 566 + "duration_ms": start.elapsed().as_millis() as u64, 567 + }), 568 + }, 569 + backend, 570 + ) 571 + .await; 572 + return Err(AppError::Internal(error_message)); 573 + } 574 + 523 575 if let Err(e) = atproto_api::register_atproto_api(&lua, state_arc, claims.map(|c| c.did())) { 524 576 let error_message = format!("failed to register atproto API: {e}"); 525 577 log_event( ··· 890 942 891 943 http_api::register_http_api(&lua, state_arc.clone()) 892 944 .map_err(|e| format!("failed to register http API: {e}"))?; 945 + 946 + super::xrpc_api::register_xrpc_api(&lua, state_arc.clone(), Some(event.did.to_string())) 947 + .map_err(|e| format!("failed to register xrpc API: {e}"))?; 893 948 894 949 atproto_api::register_atproto_api(&lua, state_arc, None) 895 950 .map_err(|e| format!("failed to register atproto API: {e}"))?;
+1
src/lua/mod.rs
··· 6 6 mod record; 7 7 pub(crate) mod sandbox; 8 8 mod tid; 9 + mod xrpc_api; 9 10 10 11 pub(crate) use execute::{ 11 12 HookEvent, execute_hook_script, execute_procedure_script, execute_query_script,
+660
src/lua/xrpc_api.rs
··· 1 + use axum::response::Response; 2 + use http_body_util::BodyExt; 3 + use mlua::{Lua, LuaSerdeExt, Result as LuaResult}; 4 + use serde_json::Value; 5 + use std::collections::HashMap; 6 + use std::sync::Arc; 7 + 8 + use crate::AppState; 9 + use crate::auth::Claims; 10 + use crate::lexicon::LexiconType; 11 + use crate::xrpc; 12 + 13 + /// Convert an axum Response into a Lua table with `{ status, body }`. 14 + async fn response_to_lua_table(lua: &Lua, response: Response) -> LuaResult<mlua::Table> { 15 + let status = response.status().as_u16(); 16 + let body_bytes = response 17 + .into_body() 18 + .collect() 19 + .await 20 + .map_err(|e| mlua::Error::runtime(format!("failed to read response body: {e}")))? 21 + .to_bytes(); 22 + let body = String::from_utf8_lossy(&body_bytes).to_string(); 23 + 24 + let table = lua.create_table()?; 25 + table.set("status", status)?; 26 + table.set("body", body)?; 27 + Ok(table) 28 + } 29 + 30 + /// Convert an Option<mlua::Table> of params into a HashMap<String, Value>. 31 + fn lua_table_to_params(lua: &Lua, table: Option<mlua::Table>) -> LuaResult<HashMap<String, Value>> { 32 + match table { 33 + Some(t) => { 34 + let value: Value = lua.from_value(mlua::Value::Table(t))?; 35 + match value { 36 + Value::Object(map) => Ok(map.into_iter().collect()), 37 + _ => Ok(HashMap::new()), 38 + } 39 + } 40 + None => Ok(HashMap::new()), 41 + } 42 + } 43 + 44 + pub fn register_xrpc_api( 45 + lua: &Lua, 46 + state: Arc<AppState>, 47 + caller_did: Option<String>, 48 + ) -> LuaResult<()> { 49 + let xrpc_table = lua.create_table()?; 50 + 51 + // xrpc.query(method, params?) 52 + { 53 + let state = state.clone(); 54 + let caller_did = caller_did.clone(); 55 + let func = lua.create_async_function( 56 + move |lua, (method, params): (String, Option<mlua::Table>)| { 57 + let state = state.clone(); 58 + let caller_did = caller_did.clone(); 59 + async move { 60 + let mut params = lua_table_to_params(&lua, params)?; 61 + let claims = caller_did.map(Claims::internal); 62 + 63 + let response = 64 + execute_local_query(&state, &method, &mut params, claims.as_ref()) 65 + .await 66 + .map_err(|e| mlua::Error::runtime(format!("xrpc query failed: {e}")))?; 67 + 68 + response_to_lua_table(&lua, response).await 69 + } 70 + }, 71 + )?; 72 + xrpc_table.set("query", func)?; 73 + } 74 + 75 + // xrpc.procedure(method, input, params?) 76 + { 77 + let state = state.clone(); 78 + let caller_did = caller_did.clone(); 79 + let func = lua.create_async_function( 80 + move |lua, (method, input, params): (String, mlua::Value, Option<mlua::Table>)| { 81 + let state = state.clone(); 82 + let caller_did = caller_did.clone(); 83 + async move { 84 + let mut params = lua_table_to_params(&lua, params)?; 85 + let input: Value = lua.from_value(input)?; 86 + let claims = caller_did.clone().map(Claims::internal).ok_or_else(|| { 87 + mlua::Error::runtime( 88 + "xrpc.procedure requires authentication (no caller_did)", 89 + ) 90 + })?; 91 + 92 + let response = 93 + execute_local_procedure(&state, &method, &claims, &input, &mut params) 94 + .await 95 + .map_err(|e| { 96 + mlua::Error::runtime(format!("xrpc procedure failed: {e}")) 97 + })?; 98 + 99 + response_to_lua_table(&lua, response).await 100 + } 101 + }, 102 + )?; 103 + xrpc_table.set("procedure", func)?; 104 + } 105 + 106 + lua.globals().set("xrpc", xrpc_table)?; 107 + Ok(()) 108 + } 109 + 110 + /// Execute a query XRPC — local handler if known, proxy if not. 111 + async fn execute_local_query( 112 + state: &AppState, 113 + method: &str, 114 + params: &mut HashMap<String, Value>, 115 + claims: Option<&Claims>, 116 + ) -> Result<Response, crate::error::AppError> { 117 + let lexicon = state.lexicons.get(method).await; 118 + 119 + match lexicon { 120 + Some(lex) => { 121 + if lex.lexicon_type != LexiconType::Query { 122 + return Err(crate::error::AppError::BadRequest(format!( 123 + "{method} is not a query endpoint" 124 + ))); 125 + } 126 + if let Some(ref param_schema) = lex.parameters { 127 + xrpc::coerce_params(params, param_schema); 128 + } 129 + xrpc::query::handle_query(state, method, params, &lex, claims).await 130 + } 131 + None => { 132 + let query_string = params_to_query_string(params); 133 + xrpc::proxy_to_authority(state, method, &query_string, None).await 134 + } 135 + } 136 + } 137 + 138 + /// Build a query string from a params HashMap (used by proxy path). 139 + fn params_to_query_string(params: &HashMap<String, Value>) -> String { 140 + params 141 + .iter() 142 + .map(|(k, v)| { 143 + let val = match v { 144 + Value::String(s) => s.clone(), 145 + other => other.to_string(), 146 + }; 147 + format!("{}={}", urlencoding::encode(k), urlencoding::encode(&val)) 148 + }) 149 + .collect::<Vec<_>>() 150 + .join("&") 151 + } 152 + 153 + /// Execute a procedure XRPC — local handler if known, proxy if not. 154 + async fn execute_local_procedure( 155 + state: &AppState, 156 + method: &str, 157 + claims: &Claims, 158 + input: &Value, 159 + params: &mut HashMap<String, Value>, 160 + ) -> Result<Response, crate::error::AppError> { 161 + let lexicon = state.lexicons.get(method).await; 162 + 163 + match lexicon { 164 + Some(lex) => { 165 + if lex.lexicon_type != LexiconType::Procedure { 166 + return Err(crate::error::AppError::BadRequest(format!( 167 + "{method} is not a procedure endpoint" 168 + ))); 169 + } 170 + if let Some(ref param_schema) = lex.parameters { 171 + xrpc::coerce_params(params, param_schema); 172 + } 173 + xrpc::procedure::handle_procedure(state, method, claims, input, params, &lex).await 174 + } 175 + None => { 176 + let query_string = params_to_query_string(params); 177 + xrpc::proxy_to_authority(state, method, &query_string, Some(input)).await 178 + } 179 + } 180 + } 181 + 182 + #[cfg(test)] 183 + mod tests { 184 + use super::*; 185 + use crate::config::Config; 186 + use crate::db::DatabaseBackend; 187 + use crate::lexicon::{LexiconRegistry, LexiconType, ParsedLexicon, ProcedureAction}; 188 + use crate::lua::sandbox; 189 + use mlua::LuaSerdeExt; 190 + use serde_json::json; 191 + use tokio::sync::watch; 192 + 193 + fn test_state() -> AppState { 194 + let config = Config { 195 + host: "127.0.0.1".into(), 196 + port: 3000, 197 + database_url: String::new(), 198 + database_backend: DatabaseBackend::Sqlite, 199 + public_url: String::new(), 200 + session_secret: "test-secret".into(), 201 + jetstream_url: String::new(), 202 + relay_url: String::new(), 203 + plc_url: String::new(), 204 + static_dir: String::new(), 205 + event_log_retention_days: 30, 206 + app_name: None, 207 + logo_uri: None, 208 + tos_uri: None, 209 + policy_uri: None, 210 + token_encryption_key: None, 211 + default_rate_limit_capacity: 100, 212 + default_rate_limit_refill_rate: 2.0, 213 + }; 214 + let (tx, _) = watch::channel(vec![]); 215 + let (labeler_tx, _) = watch::channel(()); 216 + sqlx::any::install_default_drivers(); 217 + let test_db = sqlx::AnyPool::connect_lazy("sqlite::memory:").unwrap(); 218 + let atrium_http = std::sync::Arc::new(atrium_oauth::DefaultHttpClient::default()); 219 + let did_resolver = atrium_identity::did::CommonDidResolver::new( 220 + atrium_identity::did::CommonDidResolverConfig { 221 + plc_directory_url: "https://plc.directory".into(), 222 + http_client: std::sync::Arc::clone(&atrium_http), 223 + }, 224 + ); 225 + let handle_resolver = atrium_identity::handle::AtprotoHandleResolver::new( 226 + atrium_identity::handle::AtprotoHandleResolverConfig { 227 + dns_txt_resolver: crate::dns::NativeDnsResolver::new(), 228 + http_client: atrium_http, 229 + }, 230 + ); 231 + let oauth = atrium_oauth::OAuthClient::new(atrium_oauth::OAuthClientConfig { 232 + client_metadata: atrium_oauth::AtprotoLocalhostClientMetadata { 233 + redirect_uris: Some(vec!["http://127.0.0.1:0/auth/callback".into()]), 234 + scopes: Some(vec![atrium_oauth::Scope::Known( 235 + atrium_oauth::KnownScope::Atproto, 236 + )]), 237 + }, 238 + keys: None, 239 + state_store: crate::auth::oauth_store::DbStateStore::new( 240 + test_db.clone(), 241 + DatabaseBackend::Sqlite, 242 + ), 243 + session_store: crate::auth::oauth_store::DbSessionStore::new( 244 + test_db.clone(), 245 + DatabaseBackend::Sqlite, 246 + ), 247 + resolver: atrium_oauth::OAuthResolverConfig { 248 + did_resolver, 249 + handle_resolver, 250 + authorization_server_metadata: Default::default(), 251 + protected_resource_metadata: Default::default(), 252 + }, 253 + }) 254 + .expect("Failed to create test OAuth client"); 255 + AppState { 256 + config, 257 + http: reqwest::Client::new(), 258 + db: test_db.clone(), 259 + db_backend: DatabaseBackend::Sqlite, 260 + lexicons: LexiconRegistry::new(), 261 + collections_tx: tx, 262 + labeler_subscriptions_tx: labeler_tx, 263 + rate_limiter: crate::rate_limit::RateLimiter::new( 264 + false, 265 + crate::rate_limit::RateLimitConfig { 266 + capacity: 100, 267 + refill_rate: 2.0, 268 + default_query_cost: 1, 269 + default_procedure_cost: 1, 270 + default_proxy_cost: 1, 271 + }, 272 + ), 273 + oauth: std::sync::Arc::new(crate::auth::OAuthClientRegistry::new(std::sync::Arc::new( 274 + oauth, 275 + ))), 276 + oauth_state_store: crate::auth::oauth_store::DbStateStore::new( 277 + test_db.clone(), 278 + DatabaseBackend::Sqlite, 279 + ), 280 + cookie_key: axum_extra::extract::cookie::Key::derive_from( 281 + b"test-secret-for-tests-only-not-production", 282 + ), 283 + plugin_registry: std::sync::Arc::new(crate::plugin::PluginRegistry::new()), 284 + wasm_runtime: std::sync::Arc::new( 285 + crate::plugin::WasmRuntime::new().expect("wasm runtime"), 286 + ), 287 + attestation_signer: None, 288 + } 289 + } 290 + 291 + fn make_query_lexicon(id: &str, script: Option<&str>) -> ParsedLexicon { 292 + ParsedLexicon { 293 + id: id.to_string(), 294 + lexicon_type: LexiconType::Query, 295 + record_key: None, 296 + parameters: None, 297 + input: None, 298 + output: None, 299 + record_schema: None, 300 + raw: json!({}), 301 + revision: 1, 302 + target_collection: None, 303 + action: ProcedureAction::Create, 304 + script: script.map(|s| s.to_string()), 305 + index_hook: None, 306 + token_cost: None, 307 + } 308 + } 309 + 310 + fn make_procedure_lexicon(id: &str, script: Option<&str>) -> ParsedLexicon { 311 + ParsedLexicon { 312 + id: id.to_string(), 313 + lexicon_type: LexiconType::Procedure, 314 + record_key: None, 315 + parameters: None, 316 + input: None, 317 + output: None, 318 + record_schema: None, 319 + raw: json!({}), 320 + revision: 1, 321 + target_collection: None, 322 + action: ProcedureAction::Create, 323 + script: script.map(|s| s.to_string()), 324 + index_hook: None, 325 + token_cost: None, 326 + } 327 + } 328 + 329 + // ----------------------------------------------------------------------- 330 + // Registration 331 + // ----------------------------------------------------------------------- 332 + 333 + #[tokio::test] 334 + async fn register_xrpc_api_creates_global() { 335 + let state = Arc::new(test_state()); 336 + let lua = sandbox::create_sandbox().unwrap(); 337 + register_xrpc_api(&lua, state, Some("did:plc:test".into())).unwrap(); 338 + 339 + let xrpc: mlua::Table = lua.globals().get("xrpc").unwrap(); 340 + assert!(xrpc.get::<mlua::Function>("query").is_ok()); 341 + assert!(xrpc.get::<mlua::Function>("procedure").is_ok()); 342 + } 343 + 344 + #[tokio::test] 345 + async fn register_xrpc_api_without_caller_did() { 346 + let state = Arc::new(test_state()); 347 + let lua = sandbox::create_sandbox().unwrap(); 348 + register_xrpc_api(&lua, state, None).unwrap(); 349 + 350 + let xrpc: mlua::Table = lua.globals().get("xrpc").unwrap(); 351 + assert!(xrpc.get::<mlua::Function>("query").is_ok()); 352 + assert!(xrpc.get::<mlua::Function>("procedure").is_ok()); 353 + } 354 + 355 + // ----------------------------------------------------------------------- 356 + // lua_table_to_params 357 + // ----------------------------------------------------------------------- 358 + 359 + #[test] 360 + fn lua_table_to_params_none_returns_empty() { 361 + let lua = sandbox::create_sandbox().unwrap(); 362 + let result = lua_table_to_params(&lua, None).unwrap(); 363 + assert!(result.is_empty()); 364 + } 365 + 366 + #[test] 367 + fn lua_table_to_params_converts_string_values() { 368 + let lua = sandbox::create_sandbox().unwrap(); 369 + let table = lua.create_table().unwrap(); 370 + table.set("handle", "user.bsky.social").unwrap(); 371 + table.set("limit", 10).unwrap(); 372 + 373 + let result = lua_table_to_params(&lua, Some(table)).unwrap(); 374 + assert_eq!(result.get("handle").unwrap(), "user.bsky.social"); 375 + assert_eq!(result.get("limit").unwrap(), 10); 376 + } 377 + 378 + // ----------------------------------------------------------------------- 379 + // params_to_query_string 380 + // ----------------------------------------------------------------------- 381 + 382 + #[test] 383 + fn params_to_query_string_empty() { 384 + let params = HashMap::new(); 385 + assert_eq!(params_to_query_string(&params), ""); 386 + } 387 + 388 + #[test] 389 + fn params_to_query_string_encodes_values() { 390 + let mut params = HashMap::new(); 391 + params.insert("handle".into(), Value::String("user.bsky.social".into())); 392 + let qs = params_to_query_string(&params); 393 + assert!(qs.contains("handle=user.bsky.social")); 394 + } 395 + 396 + #[test] 397 + fn params_to_query_string_url_encodes_special_chars() { 398 + let mut params = HashMap::new(); 399 + params.insert( 400 + "uri".into(), 401 + Value::String("at://did:plc:abc/col/rkey".into()), 402 + ); 403 + let qs = params_to_query_string(&params); 404 + assert!(qs.contains("uri=at%3A%2F%2Fdid%3Aplc%3Aabc%2Fcol%2Frkey")); 405 + } 406 + 407 + // ----------------------------------------------------------------------- 408 + // execute_local_query 409 + // ----------------------------------------------------------------------- 410 + 411 + #[tokio::test] 412 + async fn query_local_script_returns_json() { 413 + let state = test_state(); 414 + 415 + // Register a scripted query that returns a static response 416 + let lexicon = make_query_lexicon( 417 + "test.echo", 418 + Some(r#"function handle() return { greeting = "hello" } end"#), 419 + ); 420 + state.lexicons.upsert(lexicon).await; 421 + 422 + let mut params = HashMap::new(); 423 + let result = execute_local_query(&state, "test.echo", &mut params, None).await; 424 + assert!(result.is_ok(), "expected Ok, got: {:?}", result.err()); 425 + 426 + let response = result.unwrap(); 427 + assert_eq!(response.status(), 200); 428 + 429 + let body = response.into_body().collect().await.unwrap().to_bytes(); 430 + let json: Value = serde_json::from_slice(&body).unwrap(); 431 + assert_eq!(json["greeting"], "hello"); 432 + } 433 + 434 + #[tokio::test] 435 + async fn query_local_script_receives_params() { 436 + let state = test_state(); 437 + 438 + let lexicon = make_query_lexicon( 439 + "test.greet", 440 + Some(r#"function handle() return { greeting = "hello " .. params.name } end"#), 441 + ); 442 + state.lexicons.upsert(lexicon).await; 443 + 444 + let mut params = HashMap::new(); 445 + params.insert("name".into(), Value::String("world".into())); 446 + let result = execute_local_query(&state, "test.greet", &mut params, None).await; 447 + assert!(result.is_ok()); 448 + 449 + let body = result 450 + .unwrap() 451 + .into_body() 452 + .collect() 453 + .await 454 + .unwrap() 455 + .to_bytes(); 456 + let json: Value = serde_json::from_slice(&body).unwrap(); 457 + assert_eq!(json["greeting"], "hello world"); 458 + } 459 + 460 + #[tokio::test] 461 + async fn query_local_script_receives_caller_did() { 462 + let state = test_state(); 463 + 464 + let lexicon = make_query_lexicon( 465 + "test.whoami", 466 + Some( 467 + r#"function handle() 468 + return { did = caller_did or "anonymous" } 469 + end"#, 470 + ), 471 + ); 472 + state.lexicons.upsert(lexicon).await; 473 + 474 + // With caller_did 475 + let claims = Claims::internal("did:plc:testuser".into()); 476 + let mut params = HashMap::new(); 477 + let result = execute_local_query(&state, "test.whoami", &mut params, Some(&claims)).await; 478 + assert!(result.is_ok()); 479 + let body = result 480 + .unwrap() 481 + .into_body() 482 + .collect() 483 + .await 484 + .unwrap() 485 + .to_bytes(); 486 + let json: Value = serde_json::from_slice(&body).unwrap(); 487 + assert_eq!(json["did"], "did:plc:testuser"); 488 + 489 + // Without caller_did 490 + let mut params = HashMap::new(); 491 + let result = execute_local_query(&state, "test.whoami", &mut params, None).await; 492 + assert!(result.is_ok()); 493 + let body = result 494 + .unwrap() 495 + .into_body() 496 + .collect() 497 + .await 498 + .unwrap() 499 + .to_bytes(); 500 + let json: Value = serde_json::from_slice(&body).unwrap(); 501 + assert_eq!(json["did"], "anonymous"); 502 + } 503 + 504 + #[tokio::test] 505 + async fn query_rejects_procedure_lexicon() { 506 + let state = test_state(); 507 + 508 + let lexicon = make_procedure_lexicon("test.create", None); 509 + state.lexicons.upsert(lexicon).await; 510 + 511 + let mut params = HashMap::new(); 512 + let result = execute_local_query(&state, "test.create", &mut params, None).await; 513 + assert!(result.is_err()); 514 + let err = format!("{:?}", result.unwrap_err()); 515 + assert!(err.contains("not a query endpoint"), "got: {err}"); 516 + } 517 + 518 + #[tokio::test] 519 + async fn procedure_rejects_query_lexicon() { 520 + let state = test_state(); 521 + 522 + let lexicon = make_query_lexicon("test.echo", Some("function handle() end")); 523 + state.lexicons.upsert(lexicon).await; 524 + 525 + let claims = Claims::internal("did:plc:test".into()); 526 + let mut params = HashMap::new(); 527 + let result = 528 + execute_local_procedure(&state, "test.echo", &claims, &json!({}), &mut params).await; 529 + assert!(result.is_err()); 530 + let err = format!("{:?}", result.unwrap_err()); 531 + assert!(err.contains("not a procedure endpoint"), "got: {err}"); 532 + } 533 + 534 + // ----------------------------------------------------------------------- 535 + // Lua integration: xrpc.query from within a script 536 + // ----------------------------------------------------------------------- 537 + 538 + #[tokio::test] 539 + async fn lua_script_calls_xrpc_query() { 540 + let state = test_state(); 541 + 542 + // Register a simple query that the outer script will call 543 + let inner_lexicon = make_query_lexicon( 544 + "test.inner", 545 + Some(r#"function handle() return { value = 42 } end"#), 546 + ); 547 + state.lexicons.upsert(inner_lexicon).await; 548 + 549 + let state_arc = Arc::new(state); 550 + let lua = sandbox::create_sandbox().unwrap(); 551 + 552 + register_xrpc_api(&lua, state_arc, None).unwrap(); 553 + 554 + // Script that calls xrpc.query and parses the result 555 + lua.load( 556 + r#" 557 + function handle() 558 + local resp = xrpc.query("test.inner") 559 + local data = json.decode(resp.body) 560 + return { status = resp.status, inner_value = data.value } 561 + end 562 + "#, 563 + ) 564 + .exec() 565 + .unwrap(); 566 + 567 + // Register json global for the script 568 + let json_table = lua.create_table().unwrap(); 569 + let decode = lua 570 + .create_function(|lua, s: String| { 571 + let val: Value = serde_json::from_str(&s) 572 + .map_err(|e| mlua::Error::runtime(format!("json decode: {e}")))?; 573 + lua.to_value(&val) 574 + }) 575 + .unwrap(); 576 + json_table.set("decode", decode).unwrap(); 577 + lua.globals().set("json", json_table).unwrap(); 578 + 579 + let handle: mlua::Function = lua.globals().get("handle").unwrap(); 580 + let result: mlua::Value = handle.call_async(()).await.unwrap(); 581 + let json_result: Value = lua.from_value(result).unwrap(); 582 + 583 + assert_eq!(json_result["status"], 200); 584 + assert_eq!(json_result["inner_value"], 42); 585 + } 586 + 587 + // ----------------------------------------------------------------------- 588 + // Lua integration: xrpc.procedure requires caller_did 589 + // ----------------------------------------------------------------------- 590 + 591 + #[tokio::test] 592 + async fn lua_xrpc_procedure_fails_without_caller_did() { 593 + let state = test_state(); 594 + let state_arc = Arc::new(state); 595 + let lua = sandbox::create_sandbox().unwrap(); 596 + 597 + // Register with no caller_did 598 + register_xrpc_api(&lua, state_arc, None).unwrap(); 599 + 600 + lua.load( 601 + r#" 602 + function handle() 603 + return xrpc.procedure("test.something", {}) 604 + end 605 + "#, 606 + ) 607 + .exec() 608 + .unwrap(); 609 + 610 + let handle: mlua::Function = lua.globals().get("handle").unwrap(); 611 + let result: Result<mlua::Value, _> = handle.call_async(()).await; 612 + assert!(result.is_err()); 613 + let err = result.unwrap_err().to_string(); 614 + assert!( 615 + err.contains("requires authentication"), 616 + "expected auth error, got: {err}" 617 + ); 618 + } 619 + 620 + // ----------------------------------------------------------------------- 621 + // response_to_lua_table 622 + // ----------------------------------------------------------------------- 623 + 624 + #[tokio::test] 625 + async fn response_to_lua_table_converts_correctly() { 626 + let lua = sandbox::create_sandbox().unwrap(); 627 + let response = axum::response::Response::builder() 628 + .status(200) 629 + .body(axum::body::Body::from(r#"{"ok":true}"#)) 630 + .unwrap(); 631 + 632 + let table = response_to_lua_table(&lua, response).await.unwrap(); 633 + assert_eq!(table.get::<u16>("status").unwrap(), 200); 634 + assert_eq!(table.get::<String>("body").unwrap(), r#"{"ok":true}"#); 635 + } 636 + 637 + #[tokio::test] 638 + async fn response_to_lua_table_preserves_error_status() { 639 + let lua = sandbox::create_sandbox().unwrap(); 640 + let response = axum::response::Response::builder() 641 + .status(404) 642 + .body(axum::body::Body::from("not found")) 643 + .unwrap(); 644 + 645 + let table = response_to_lua_table(&lua, response).await.unwrap(); 646 + assert_eq!(table.get::<u16>("status").unwrap(), 404); 647 + assert_eq!(table.get::<String>("body").unwrap(), "not found"); 648 + } 649 + 650 + // ----------------------------------------------------------------------- 651 + // Claims::internal 652 + // ----------------------------------------------------------------------- 653 + 654 + #[test] 655 + fn claims_internal_has_no_client_key() { 656 + let claims = Claims::internal("did:plc:test".into()); 657 + assert_eq!(claims.did(), "did:plc:test"); 658 + assert!(claims.client_key().is_none()); 659 + } 660 + }
+5 -5
src/xrpc/mod.rs
··· 1 - mod procedure; 2 - mod query; 1 + pub(crate) mod procedure; 2 + pub(crate) mod query; 3 3 4 4 use axum::Json; 5 5 use axum::body::Body; ··· 19 19 20 20 /// Parse a raw query string into a map where repeated keys become JSON arrays. 21 21 /// Single-value keys remain as JSON strings for backward compatibility. 22 - fn parse_query_params(query: &str) -> HashMap<String, Value> { 22 + pub(crate) fn parse_query_params(query: &str) -> HashMap<String, Value> { 23 23 let mut multi: HashMap<String, Vec<String>> = HashMap::new(); 24 24 for pair in query.split('&') { 25 25 if pair.is_empty() { ··· 54 54 /// HTTP query params arrive as strings. Without this, Lua scripts receive 55 55 /// `"25"` (a string) for `params.limit`, which Postgres rejects when used 56 56 /// in LIMIT (`argument of LIMIT must be type bigint, not type text`). 57 - fn coerce_params(params: &mut HashMap<String, Value>, parameters: &Value) { 57 + pub(crate) fn coerce_params(params: &mut HashMap<String, Value>, parameters: &Value) { 58 58 let properties = match parameters.get("properties").and_then(|p| p.as_object()) { 59 59 Some(p) => p, 60 60 None => return, ··· 98 98 } 99 99 100 100 /// Proxy an unrecognized XRPC method to its home AppView resolved via DNS. 101 - async fn proxy_to_authority( 101 + pub(crate) async fn proxy_to_authority( 102 102 state: &AppState, 103 103 method: &str, 104 104 query_string: &str,
+1 -1
src/xrpc/procedure.rs
··· 10 10 use crate::record_refs::sync_refs; 11 11 use crate::repo; 12 12 13 - pub(super) async fn handle_procedure( 13 + pub(crate) async fn handle_procedure( 14 14 state: &AppState, 15 15 method: &str, 16 16 claims: &Claims,
+1 -1
src/xrpc/query.rs
··· 8 8 use crate::db::adapt_sql; 9 9 use crate::error::AppError; 10 10 11 - pub(super) async fn handle_query( 11 + pub(crate) async fn handle_query( 12 12 state: &AppState, 13 13 method: &str, 14 14 params: &HashMap<String, Value>,