A local-first private AI assistant for everyday use. Runs on-device models with encrypted P2P sync, and supports sharing chats publicly on ATProto.
10
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat: Added /share command that generates a link which has session in PDS

madclaws 9a318296 d3c7a65a

+263 -138
+142 -38
tiles/src/core/account/atproto.rs
··· 1 1 //! Handles atprotocol stuff 2 2 3 3 use anyhow::{Result, anyhow}; 4 - use atrium_api::types::string::Did; 5 - use atrium_common::store::Store; 4 + use atrium_api::{ 5 + agent::Agent, 6 + types::{ 7 + Unknown, 8 + string::{Datetime, Did}, 9 + }, 10 + }; 11 + use atrium_common::store::{Store, memory::MemoryStore}; 6 12 use atrium_identity::{ 7 13 did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}, 8 14 handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver}, ··· 10 16 use atrium_oauth::{ 11 17 AtprotoLocalhostClientMetadata, AuthorizeOptions, CallbackParams, DefaultHttpClient, 12 18 KnownScope, OAuthClient, OAuthClientConfig, OAuthResolverConfig, Scope, 13 - store::{session::MemorySessionStore, state::MemoryStateStore}, 19 + store::{ 20 + session::{MemorySessionStore, Session}, 21 + state::{InternalStateData, MemoryStateStore}, 22 + }, 14 23 }; 15 24 use log::info; 16 25 use reqwest::Client; 17 26 use rusqlite::{Connection, OptionalExtension, params}; 18 27 use serde::{Deserialize, Serialize}; 28 + use serde_json::json; 19 29 use std::{fmt::Debug, process::Command, sync::Arc, time::Duration}; 20 30 use tokio::sync::oneshot; 21 31 ··· 23 33 24 34 use hickory_resolver::TokioResolver; 25 35 26 - use crate::{core::storage::db::Dbconn, daemon::start_internal_server, utils::get_unix_time_now}; 36 + use crate::{ 37 + core::storage::db::Dbconn, daemon::start_internal_server, runtime::mlx::SharedSession, 38 + utils::get_unix_time_now, 39 + }; 40 + 41 + // TODO: Make this dynamic porting 42 + const LOGIN_PORT: u32 = 8988; 43 + type TOAuthClient = OAuthClient< 44 + MemoryStore<String, InternalStateData>, 45 + MemoryStore<Did, Session>, 46 + CommonDidResolver<DefaultHttpClient>, 47 + AtprotoHandleResolver<HickoryDnsTxtResolver, DefaultHttpClient>, 48 + >; 27 49 28 50 #[derive(Deserialize)] 29 51 struct HandleResolve { ··· 86 108 } 87 109 88 110 pub async fn login(conn: &Dbconn, handle: &str) -> Result<()> { 89 - let http_client = Arc::new(DefaultHttpClient::default()); 90 - const LOGIN_PORT: u32 = 8988; 91 - 92 - let mem_session_store = MemorySessionStore::default(); 93 - let mem_state_store = MemoryStateStore::default(); 94 - 95 - let config = OAuthClientConfig { 96 - client_metadata: AtprotoLocalhostClientMetadata { 97 - redirect_uris: Some(vec![String::from("http://127.0.0.1:8988/callback")]), 98 - scopes: Some(vec![ 99 - Scope::Known(KnownScope::Atproto), 100 - Scope::Known(KnownScope::TransitionGeneric), 101 - ]), 102 - }, 103 - keys: None, 104 - resolver: OAuthResolverConfig { 105 - did_resolver: CommonDidResolver::new(CommonDidResolverConfig { 106 - plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), 107 - http_client: http_client.clone(), 108 - }), 109 - handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { 110 - dns_txt_resolver: HickoryDnsTxtResolver::default(), 111 - http_client: http_client.clone(), 112 - }), 113 - authorization_server_metadata: Default::default(), 114 - protected_resource_metadata: Default::default(), 115 - }, 116 - state_store: mem_state_store.clone(), 117 - session_store: mem_session_store.clone(), 118 - }; 119 - 120 - let Ok(client) = OAuthClient::new(config) else { 121 - panic!("client fuck up") 122 - }; 111 + let (client, mem_session_store) = create_oauth_client()?; 123 112 124 113 //TODO: This resolve function is hack to convert handle to DID 125 114 // cuz for some reason the authorize fn not working for customd domains ··· 297 286 }) 298 287 }, 299 288 ).optional().map_err(Into::<anyhow::Error>::into) 289 + } 290 + 291 + //TODO: Move the login check to common fn 292 + pub async fn share_session(conn: &Connection, shared_session: SharedSession) -> Result<()> { 293 + if let Some(auth_data) = fetch_logged_in_data(&conn)? { 294 + let (client, mem_session_store) = create_oauth_client()?; 295 + let session: Session = serde_json::from_str(&auth_data.session)?; 296 + let did_struct = 297 + Did::new(auth_data.key.clone()).map_err(|_e| anyhow!("Failed to convert to Did"))?; 298 + 299 + mem_session_store.set(did_struct.clone(), session).await?; 300 + 301 + //TODO: Add a user friendly err latta 302 + let oauth_session = client.restore(&did_struct).await?; 303 + let agent = Agent::new(oauth_session); 304 + 305 + // let test_record = json!({ 306 + // "$type": "run.tiles.session", 307 + // "session_id": "019dd050-f337-7507-a8bc-b5eaf3547cc5", 308 + // "name": "dummy_session", 309 + // "contents": [ 310 + // { 311 + // "role": "user", 312 + // "content": "dummy content" 313 + // } 314 + // ], 315 + // "created_at": Datetime::now().as_str() 316 + // }); 317 + 318 + let shared_session_value = serde_json::to_value(shared_session)?; 319 + let record: Unknown = serde_json::from_value(shared_session_value)?; 320 + 321 + //TODO: can we remove the unwrap at collection 322 + let create_result = agent 323 + .api 324 + .com 325 + .atproto 326 + .repo 327 + .create_record( 328 + atrium_api::com::atproto::repo::create_record::InputData { 329 + collection: "run.tiles.session".parse().unwrap(), 330 + repo: did_struct.clone().into(), 331 + rkey: None, 332 + record, 333 + swap_commit: None, 334 + validate: None, 335 + } 336 + .into(), 337 + ) 338 + .await?; 339 + 340 + let url = &create_result.uri; 341 + 342 + let base_encoded_at_url = data_encoding::BASE64.encode(url.as_bytes()); 343 + 344 + let shareable_url = format!("https://tiles.run/share/{}", base_encoded_at_url); 345 + println!("successfully posted at {}", shareable_url); 346 + 347 + // Updating the session token 348 + let session = mem_session_store 349 + .get(&did_struct) 350 + .await? 351 + .expect("Expected Session"); 352 + let session_string = serde_json::to_string(&session)?; 353 + 354 + let auth_data = AtprotoAuthData { 355 + key: did_struct.to_string(), 356 + session: session_string, 357 + state: "".to_owned(), 358 + is_logged_in: true, 359 + created_at: get_unix_time_now(), 360 + updated_at: get_unix_time_now(), 361 + handle: auth_data.handle, 362 + }; 363 + 364 + upsert_auth_data(&conn, &auth_data)?; 365 + } else { 366 + println!("No logged-in user, please login") 367 + } 368 + Ok(()) 369 + } 370 + 371 + fn create_oauth_client() -> Result<(TOAuthClient, MemorySessionStore)> { 372 + let http_client = Arc::new(DefaultHttpClient::default()); 373 + 374 + let mem_session_store = MemorySessionStore::default(); 375 + let mem_state_store = MemoryStateStore::default(); 376 + 377 + let config = OAuthClientConfig { 378 + client_metadata: AtprotoLocalhostClientMetadata { 379 + redirect_uris: Some(vec![String::from("http://127.0.0.1:8988/callback")]), 380 + scopes: Some(vec![ 381 + Scope::Known(KnownScope::Atproto), 382 + Scope::Known(KnownScope::TransitionGeneric), 383 + ]), 384 + }, 385 + keys: None, 386 + resolver: OAuthResolverConfig { 387 + did_resolver: CommonDidResolver::new(CommonDidResolverConfig { 388 + plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), 389 + http_client: http_client.clone(), 390 + }), 391 + handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { 392 + dns_txt_resolver: HickoryDnsTxtResolver::default(), 393 + http_client: http_client.clone(), 394 + }), 395 + authorization_server_metadata: Default::default(), 396 + protected_resource_metadata: Default::default(), 397 + }, 398 + state_store: mem_state_store.clone(), 399 + session_store: mem_session_store.clone(), 400 + }; 401 + 402 + let client = OAuthClient::new(config).map_err(Into::<anyhow::Error>::into)?; 403 + Ok((client, mem_session_store)) 300 404 } 301 405 302 406 #[cfg(test)]
+35 -9
tiles/src/core/chats.rs
··· 41 41 #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] 42 42 pub struct Chats { 43 43 pub id: String, 44 - content: String, 44 + pub content: String, 45 45 // The id of the responses api obj 46 46 response_id: Option<String>, 47 47 // The Model chat user role 48 - role: Role, 48 + pub role: Role, 49 49 user_id: String, 50 50 // The parent Id of a model's reply 51 51 context_id: Option<String>, ··· 58 58 #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] 59 59 pub struct Session { 60 60 pub id: String, 61 - name: String, 62 - created_at: u64, 61 + pub name: String, 62 + pub created_at: u64, 63 63 creator_id: String, 64 64 } 65 65 ··· 135 135 Err(err) => Err(<rusqlite::Error as Into<anyhow::Error>>::into(err)), 136 136 } 137 137 } 138 + 138 139 /// Return a Delta of chats and sessions for the given `user_id` since `last_row_counter` 139 140 pub fn get_delta(conn: &Connection, user_id: &str, last_row_couter: i64) -> Result<DeltaChat> { 140 - let mut session_map: HashMap<String, Session> = HashMap::new(); 141 + let query = "select id, user_id, content, resp_id, role, context_id, created_at, updated_at , row_counter, session_id from chats where user_id = ?1 and row_counter > ?2 order by id"; 141 142 142 - let mut stmt = conn.prepare("select id, user_id, content, resp_id, role, context_id, created_at, updated_at , row_counter, session_id from chats where user_id = ?1 and row_counter > ?2 order by id")?; 143 + let lrc_str = last_row_couter.to_string(); 144 + 145 + let params = vec![("?1", user_id), ("?2", &lrc_str)]; 146 + fetch_delta_chats(conn, query, params) 147 + } 143 148 144 - let chat_rows = stmt.query_map(params![user_id, last_row_couter], |row| { 149 + fn fetch_delta_chats( 150 + conn: &Connection, 151 + query: &str, 152 + params: Vec<(&str, &str)>, 153 + ) -> Result<DeltaChat> { 154 + let mut stmt = conn.prepare(query)?; 155 + 156 + let mut session_map: HashMap<String, Session> = HashMap::new(); 157 + let chat_rows = stmt.query_map(params.as_slice(), |row| { 145 158 let id: String = row.get(0)?; 146 159 let role: String = row.get(4)?; 147 160 let created_at: f64 = row.get(6)?; ··· 189 202 190 203 Ok(DeltaChat { chats, sessions }) 191 204 } 192 - 193 205 pub fn apply_delta(chat_conn: &mut Connection, delta_chats: DeltaChat) -> Result<()> { 194 206 // TODO: Handle primary key conflict, for now reject it (in a way its impossible to have this scenario, and if its occuring then that means 195 207 // some issue in syncing, so ignore it, by rejecting it), later ··· 348 360 } 349 361 } 350 362 351 - fn fetch_session(conn: &Connection, session_id: &str) -> Result<Session> { 363 + pub fn fetch_session(conn: &Connection, session_id: &str) -> Result<Session> { 352 364 let sesh = conn.query_row( 353 365 "SELECT id, name, creator_id, created_at FROM sessions WHERE id = ?1", 354 366 [session_id], ··· 369 381 370 382 fn decode_delta_from_bytes(bytes: &[u8]) -> Result<DeltaChat> { 371 383 postcard::from_bytes(bytes).map_err(Into::into) 384 + } 385 + 386 + pub fn fetch_chats_by_session_id(conn: &Connection, session_id: &str) -> Result<DeltaChat> { 387 + let query = "select id, user_id, content, resp_id, role, context_id, created_at, updated_at , row_counter, session_id from chats where session_id = ?1 order by id"; 388 + 389 + let params = vec![("?1", session_id)]; 390 + 391 + fetch_delta_chats(conn, query, params) 372 392 } 373 393 374 394 #[cfg(test)] ··· 1063 1083 creator_id TEXT NOT NULL, 1064 1084 created_at INTEGER NOT NULL 1065 1085 )", 1086 + [], 1087 + ) 1088 + .unwrap(); 1089 + 1090 + conn.execute( 1091 + "CREATE INDEX idx_chats_session_id ON chats(session_id);", 1066 1092 [], 1067 1093 ) 1068 1094 .unwrap();
+1
tiles/src/core/storage/db.rs
··· 90 90 created_at INTEGER NOT NULL 91 91 )", 92 92 ), 93 + M::up("CREATE INDEX idx_chats_session_id ON chats(session_id);"), 93 94 ]; 94 95 95 96 const CHATS_MIGRATIONS: Migrations = Migrations::from_slice(CHATS_MIGRATION_ARRAY);
+1
tiles/src/main.rs
··· 320 320 login(&db_conn, &handle).await?; 321 321 } 322 322 AtCommands::Logout => logout(&db_conn)?, 323 + // AtCommands::Share => share_session(&db_conn).await?, 323 324 }, 324 325 } 325 326 Ok(())
+84 -91
tiles/src/runtime/mlx.rs
··· 1 + use crate::core::account::atproto::share_session; 1 2 use crate::core::account::local::get_current_user; 2 - use crate::core::chats::{Message, create_session, save_chat}; 3 + use crate::core::chats::{ 4 + self, Message, create_session, fetch_chats_by_session_id, fetch_session, save_chat, 5 + }; 3 6 use crate::core::storage::db::Dbconn; 4 7 use crate::runtime::RunArgs; 5 8 use crate::utils::config::{ ··· 7 10 }; 8 11 use crate::utils::hf_model_downloader::*; 9 12 use anyhow::{Context, Result, anyhow}; 13 + use atrium_api::types::string::Datetime; 10 14 use log::info; 11 15 use reqwest::{Client, StatusCode}; 16 + use rusqlite::Connection; 12 17 use rustyline::completion::Completer; 13 18 use rustyline::highlight::Highlighter; 14 19 use rustyline::hint::Hinter; ··· 262 267 enum CommandType { 263 268 #[serde(rename = "get_state")] 264 269 State, 270 + #[serde(rename = "share")] 271 + Share, 265 272 #[serde(other)] 266 273 Unknown, 267 274 } 275 + 276 + #[derive(Serialize, Deserialize, Debug)] 277 + pub struct SharedSession { 278 + #[serde(rename = "$type")] 279 + r#type: String, 280 + session_id: String, 281 + name: String, 282 + contents: Vec<SharedContent>, 283 + created_at: String, 284 + } 285 + 286 + #[derive(Serialize, Deserialize, Debug)] 287 + pub struct SharedContent { 288 + role: Role, 289 + content: String, 290 + } 291 + 268 292 fn handle_input(input: &str, modelname: &str) -> InputType { 269 293 if let Some(cmd) = input.strip_prefix('/') { 270 294 match cmd { ··· 338 362 let config = Config::builder().auto_add_history(true).build(); 339 363 let mut editor = Editor::<TilesHinter, DefaultHistory>::with_config(config).unwrap(); 340 364 editor.set_helper(Some(TilesHinter)); 341 - // let mut g_reply: String = "".to_owned(); 342 - // let mut prev_response_id: String = String::from(""); 343 - 344 - // let mut conversations: Vec<Message> = vec![]; 345 365 346 366 let mut pi_process = start_pi_rpc(&modelname)?; 347 367 let mut session_id = String::new(); ··· 413 433 send_to_pi(pi_stdin, payload)?; 414 434 } 415 435 InputType::Command(cmd) => { 436 + let args: Vec<&str> = cmd.split(" ").collect(); 416 437 let cmd_json = json!(cmd); 438 + // println!("{}", cmd_json.to_string()); 417 439 let command: CommandType = serde_json::from_value(cmd_json)?; 418 440 match command { 419 441 CommandType::Unknown => { ··· 423 445 ); 424 446 continue; 425 447 } 448 + CommandType::Share => { 449 + process_share_session(&db_conn, &session_id, &args).await?; 450 + continue; 451 + } 426 452 cmd_type => { 427 453 let payload = get_command_payload(cmd_type); 428 454 send_to_pi(pi_stdin, payload) ··· 432 458 } 433 459 } 434 460 435 - // let mut bench_metrics: BenchmarkMetrics = BenchmarkMetrics { 436 - // ttft_ms: 0.0, 437 - // total_tokens: 0, 438 - // tokens_per_second: 0.0, 439 - // total_latency_s: 0.0, 440 - // }; 441 461 let reader = BufReader::new(&mut stdout); 442 462 let mut session_turn_count = 0; 443 463 let mut last_chat_id: String = "".to_owned(); ··· 524 544 } 525 545 } 526 546 } 527 - // loop { 528 - // if remaining_count > 0 { 529 - // let chat_start = remaining_count == run_args.relay_count; 530 - 531 - // match chat( 532 - // &input, 533 - // modelfile, 534 - // chat_start, 535 - // &python_code, 536 - // &g_reply, 537 - // run_args, 538 - // &prev_response_id, 539 - // &db_conn.chat, 540 - // &current_user, 541 - // &conversations, 542 - // ) 543 - // .await 544 - // { 545 - // Ok(response) => { 546 - // if response.reply.is_empty() { 547 - // if !response.code.is_empty() { 548 - // python_code = response.code; 549 - // } 550 - // if let Some(metrics) = response.metrics { 551 - // bench_metrics.update(metrics); 552 - // } 553 - // remaining_count -= 1; 554 - // } else { 555 - // g_reply = response.reply.clone(); 556 - // if run_args.memory { 557 - // println!("\n{}", response.reply.trim()); 558 - // } else { 559 - // prev_response_id = response.prev_response_id.clone(); 560 - // println!("\n"); 561 - // } 562 - // conversations.push(Message { 563 - // r#type: String::from("message"), 564 - // role: Role::User, 565 - // content: input, 566 - // }); 567 - // conversations.push(Message { 568 - // r#type: String::from("message"), 569 - // role: Role::Assistant, 570 - // content: g_reply.clone(), 571 - // }); 572 - 573 - // save_chat(&db_conn.chat, &current_user, &g_reply, Some(&response))?; 574 - // // Display benchmark metrics if available 575 - // if let Some(metrics) = response.metrics { 576 - // bench_metrics.update(metrics); 577 - // println!( 578 - // "{}", 579 - // format!( 580 - // "\n{} {:.1} tok/s | {} tokens | {:.0}s TTFT", 581 - // "💡".yellow(), 582 - // bench_metrics.total_tokens as f64 583 - // / bench_metrics.total_latency_s, 584 - // bench_metrics.total_tokens, 585 - // bench_metrics.ttft_ms / 1000.0 586 - // ) 587 - // .dimmed() 588 - // ); 589 - // } 590 - 591 - // break; 592 - // } 593 - // } 594 - // Err(err) => { 595 - // // if out of relay count, then clear the global_reply and ready for next fresh prompt 596 - // println!("{:?}", err); 597 - // g_reply.clear(); 598 - // break; 599 - // } 600 - // } 601 - // } 602 - // } 603 - // if g_reply.is_empty() { 604 - // println!("\nNo reply, try another prompt"); 605 - // } 606 547 } 607 548 Ok(()) 608 549 } ··· 940 881 let pi_process = Command::new(pi_exec_path) 941 882 .arg("--mode") 942 883 .arg("rpc") 943 - .arg("--no-session") 884 + // .arg("--no-session") 944 885 .env("PI_CODING_AGENT_DIR", pi_agent_dir) 945 886 .env("PI_OFFLINE", "true") 946 887 .stdin(Stdio::piped()) ··· 971 912 "type": "get_state", 972 913 }) 973 914 } 915 + // catch-all cases are where prolly its not a Pi command 916 + _ => json!([]), 974 917 } 975 918 } 976 919 ··· 983 926 use std::io::Write; 984 927 std::io::stdout().flush().ok(); 985 928 } 929 + // catch-all cases are non-Pi commands 930 + _ => (), 986 931 } 987 932 Ok(()) 988 933 } 934 + 935 + async fn process_share_session( 936 + conn: &Dbconn, 937 + current_session_id: &str, 938 + args: &[&str], 939 + ) -> Result<()> { 940 + let args = if let Some((_main_command, sub_commands)) = args.split_first() { 941 + sub_commands 942 + } else { 943 + println!("Not a valid command"); 944 + return Ok(()); 945 + }; 946 + 947 + let session_id = if args.is_empty() { 948 + current_session_id 949 + } else { 950 + args[0] 951 + }; 952 + // fetch session and the chats for the session_id 953 + 954 + let delta_chats = fetch_chats_by_session_id(&conn.chat, session_id)?; 955 + 956 + if delta_chats.sessions.is_empty() { 957 + println!("Session {} not available", session_id); 958 + } 959 + 960 + let session = &delta_chats.sessions[0]; 961 + 962 + let mut shared_contents: Vec<SharedContent> = vec![]; 963 + for chat in delta_chats.chats { 964 + shared_contents.push(SharedContent { 965 + role: chat.role, 966 + content: chat.content, 967 + }); 968 + } 969 + 970 + let shared_sessions = SharedSession { 971 + r#type: "run.tiles.session".to_string(), 972 + session_id: session_id.to_string(), 973 + name: session.name.clone(), 974 + contents: shared_contents, 975 + created_at: Datetime::now().as_str().to_string(), 976 + }; 977 + 978 + share_session(&conn.common, shared_sessions).await?; 979 + // pass it to the atproto share_session fn 980 + Ok(()) 981 + }