A local-first private AI assistant for everyday use. Runs on-device models with encrypted P2P sync, and supports sharing chats publicly on ATProto.
10
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat: Added sessions support

+546 -268
pi-darwin-arm64.tar.gz

This is a binary file and will not be displayed.

+294 -105
tiles/src/core/chats.rs
··· 10 10 use crate::runtime::mlx::ChatResponse; 11 11 use crate::utils::get_unix_time_now; 12 12 use anyhow::{Result, anyhow}; 13 - use log::info; 13 + use log::{info, warn}; 14 14 use rusqlite::types::FromSqlError; 15 15 use rusqlite::{Connection, params}; 16 16 use tilekit::modelfile::Role; ··· 51 51 created_at: u64, 52 52 updated_at: u64, 53 53 row_counter: i64, 54 + session_id: String, 55 + } 56 + 57 + #[derive(Debug, serde::Serialize, serde::Deserialize)] 58 + pub struct Session { 59 + pub id: String, 60 + name: String, 61 + created_at: u64, 62 + creator_id: String, 54 63 } 55 64 56 65 type Responder<T> = oneshot::Sender<T>; ··· 70 79 }, 71 80 } 72 81 73 - pub fn save_chat( 74 - conn: &Connection, 75 - user: &User, 76 - input: &str, 77 - chat_resp: Option<&ChatResponse>, 78 - ) -> Result<Chats> { 82 + pub fn save_chat(conn: &Connection, user: &User, chat_resp: ChatResponse) -> Result<Chats> { 79 83 let row_counter = get_last_row_counter(conn, &user.user_id)?; 80 - if let Some(chat_response) = chat_resp { 81 - let chat_resp_cloned = chat_response.clone(); 82 - 83 - let chat = Chats { 84 - id: Uuid::now_v7().to_string(), 85 - user_id: user.user_id.clone(), 86 - content: input.to_owned(), 87 - response_id: Some(chat_resp_cloned.prev_response_id), 88 - role: Role::Assistant, 89 - context_id: chat_resp_cloned.parent_chat_id, 90 - created_at: get_unix_time_now(), 91 - updated_at: get_unix_time_now(), 92 - row_counter: row_counter + 1, 93 - }; 94 - 95 - conn.execute("insert into chats(id, user_id, content, resp_id, role, context_id, created_at, updated_at, row_counter) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", (&chat.id.to_string(), &chat.user_id, &chat.content, &chat.response_id, Into::<String>::into(chat.role), &chat.context_id, &chat.created_at.to_string(), &chat.updated_at.to_string(), &chat.row_counter))?; 96 - 97 - Ok(chat) 98 - } else { 99 - let chat = Chats { 100 - id: Uuid::now_v7().to_string(), 101 - user_id: user.user_id.clone(), 102 - content: input.to_owned(), 103 - response_id: None, 104 - role: Role::User, 105 - context_id: None, 106 - created_at: get_unix_time_now(), 107 - updated_at: get_unix_time_now(), 108 - row_counter: row_counter + 1, 109 - }; 84 + let chat = Chats { 85 + id: Uuid::now_v7().to_string(), 86 + user_id: user.user_id.clone(), 87 + content: chat_resp.input, 88 + response_id: None, 89 + role: chat_resp.role, 90 + context_id: chat_resp.parent_chat_id, 91 + created_at: get_unix_time_now(), 92 + updated_at: get_unix_time_now(), 93 + row_counter: row_counter + 1, 94 + session_id: chat_resp.session_id, 95 + }; 110 96 111 - conn.execute("insert into chats(id, user_id, content, resp_id, role, context_id, created_at, updated_at, row_counter) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", (&chat.id, &chat.user_id, &chat.content, &chat.response_id, Into::<String>::into(chat.role), &chat.context_id, &chat.created_at.to_string(), &chat.updated_at.to_string(), &chat.row_counter))?; 97 + conn.execute("insert into chats(id, user_id, content, resp_id, role, context_id, created_at, updated_at, row_counter, session_id) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", (&chat.id.to_string(), &chat.user_id, &chat.content, &chat.response_id, Into::<String>::into(chat.role), &chat.context_id, &chat.created_at.to_string(), &chat.updated_at.to_string(), &chat.row_counter, &chat.session_id))?; 112 98 113 - Ok(chat) 114 - } 99 + Ok(chat) 115 100 } 116 101 117 102 /// Returns the `id` of the last entry of the given user_id ··· 145 130 } 146 131 /// Return list of rows for the given `user_id` since `last_row_counter` 147 132 pub fn get_delta(conn: &Connection, user_id: &str, last_row_couter: i64) -> Result<Vec<Chats>> { 148 - let mut stmt = conn.prepare("select id, user_id, content, resp_id, role, context_id, created_at, updated_at , row_counter from chats where user_id = ?1 and row_counter > ?2 order by id")?; 133 + let mut stmt = conn.prepare("select id, user_id, content, resp_id, role, context_id, created_at, updated_at , row_counter, session_id from chats where user_id = ?1 and row_counter > ?2 order by id")?; 149 134 150 135 let chat_rows = stmt.query_map(params![user_id, last_row_couter], |row| { 151 136 let id: String = row.get(0)?; ··· 164 149 created_at: created_at as u64, 165 150 updated_at: updated_at as u64, 166 151 row_counter: row.get(8)?, 152 + session_id: row.get(9)?, 167 153 }) 168 154 })?; 169 155 ··· 184 170 185 171 let txn = chat_conn.transaction()?; 186 172 { 187 - let mut stmt = txn.prepare("insert into chats(id, user_id, content, resp_id, role, context_id, created_at, updated_at, row_counter) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)")?; 173 + let mut stmt = txn.prepare("insert into chats(id, user_id, content, resp_id, role, context_id, created_at, updated_at, row_counter, session_id) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)")?; 188 174 189 175 for chat in delta_chats { 190 176 match stmt.execute(params![ ··· 197 183 &chat.created_at.to_string(), 198 184 &chat.updated_at.to_string(), 199 185 &chat.row_counter, 186 + &chat.session_id 200 187 ]) { 201 188 Err(rusqlite::Error::SqliteFailure(_, Some(reason))) 202 189 if reason == "UNIQUE constraint failed: chats.id" => ··· 269 256 tx 270 257 } 271 258 259 + pub fn create_session(conn: &Connection, id: &str, name: &str, user_id: &str) -> Result<Session> { 260 + // log a warning if session already exists, and skip the conflict 261 + 262 + let mut stmt = conn.prepare( 263 + "insert into sessions(id, name, creator_id, created_at) values (?1, ?2, ?3, ?4)", 264 + )?; 265 + 266 + match stmt.execute(params![ 267 + id.to_owned(), 268 + name.to_owned(), 269 + user_id.to_owned(), 270 + get_unix_time_now() as f64 271 + ]) { 272 + Ok(_res) => { 273 + let sesh = fetch_session(conn, id)?; 274 + Ok(sesh) 275 + } 276 + Err(rusqlite::Error::SqliteFailure(_, Some(reason))) 277 + if reason == "UNIQUE constraint failed: sessions.id" => 278 + { 279 + warn!("Session entry already exists, skipping"); 280 + let sesh = fetch_session(conn, id)?; 281 + Ok(sesh) 282 + } 283 + Err(err) => Err(anyhow!("Err inserting due to {}", err)), 284 + } 285 + } 286 + 287 + fn fetch_session(conn: &Connection, session_id: &str) -> Result<Session> { 288 + let sesh = conn.query_row( 289 + "SELECT id, name, creator_id, created_at FROM sessions WHERE id = ?1", 290 + [session_id], 291 + |row| { 292 + Ok(Session { 293 + id: row.get(0)?, 294 + name: row.get(1)?, 295 + creator_id: row.get(2)?, 296 + created_at: row.get::<usize, f64>(3)? as u64, 297 + }) 298 + }, 299 + )?; 300 + Ok(sesh) 301 + } 272 302 fn encode_delta_to_bytes(delta_chats: &Vec<Chats>) -> Vec<u8> { 273 303 postcard::to_stdvec(delta_chats).expect("Failed to convert to bytes with postcard") 274 304 } ··· 290 320 core::{ 291 321 accounts::{ACCOUNT, User}, 292 322 chats::{ 293 - apply_delta, decode_delta_from_bytes, encode_delta_to_bytes, get_delta, 294 - get_last_row_counter, save_chat, 323 + apply_delta, create_session, decode_delta_from_bytes, encode_delta_to_bytes, 324 + get_delta, get_last_row_counter, save_chat, 295 325 }, 296 326 }, 297 327 runtime::mlx::ChatResponse, ··· 303 333 let conn = setup_db_schema(); 304 334 let user = create_user(); 305 335 let input = "2+2"; 306 - let chat = save_chat(&conn, &user, input, None).expect("chat should be saved"); 336 + 337 + let chat_response = ChatResponse { 338 + input: input.to_owned(), 339 + session_id: String::from("session_abc"), 340 + role: Role::User, 341 + code: None, 342 + prev_response_id: None, 343 + parent_chat_id: None, 344 + metrics: None, 345 + }; 346 + let chat = save_chat(&conn, &user, chat_response).expect("chat should be saved"); 307 347 308 348 assert_eq!(chat.user_id, user.user_id); 309 349 assert!(chat.response_id.is_none()); ··· 322 362 let conn = setup_db_schema(); 323 363 let user = create_user(); 324 364 let parent_chat_id = Uuid::now_v7().to_string(); 325 - let chat_resp = ChatResponse { 326 - reply: "reply".to_owned(), 327 - code: "code".to_owned(), 328 - prev_response_id: String::from("resp_prev"), 365 + let input = "2+2"; 366 + let chat_response = ChatResponse { 367 + input: input.to_owned(), 368 + session_id: String::from("session_abc"), 369 + role: Role::Assistant, 370 + code: None, 371 + prev_response_id: None, 329 372 parent_chat_id: Some(parent_chat_id.clone()), 330 373 metrics: None, 331 374 }; 332 - let input = "2+2"; 333 - let chat = save_chat(&conn, &user, input, Some(&chat_resp)).expect("chat should be saved"); 375 + let chat = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 334 376 335 377 assert_eq!(chat.user_id, user.user_id); 336 - assert_eq!(chat.response_id.as_deref(), Some("resp_prev")); 337 378 assert_eq!(chat.context_id, Some(parent_chat_id.clone())); 338 379 339 380 let saved = fetch_saved_chat_row(&conn, &chat.id); 340 381 assert_eq!(saved.content, input); 341 - assert_eq!(saved.resp_id, Some(String::from("resp_prev"))); 342 382 assert_eq!(saved.role, Into::<String>::into(Role::Assistant)); 343 383 assert_eq!(saved.user_id, user.user_id); 344 384 assert_eq!(saved.context_id, Some(parent_chat_id.clone())); ··· 348 388 fn test_response_without_parent_chat_id_saves_nil_context() { 349 389 let conn = setup_db_schema(); 350 390 let user = create_user(); 351 - let chat_resp = ChatResponse { 352 - reply: "reply".to_owned(), 353 - code: "code".to_owned(), 354 - prev_response_id: String::from("resp_prev"), 391 + let chat_response = ChatResponse { 392 + input: "".to_owned(), 393 + session_id: String::from("session_abc"), 394 + role: Role::Assistant, 395 + code: None, 396 + prev_response_id: Some(Uuid::now_v7().to_string()), 355 397 parent_chat_id: Some(Uuid::now_v7().to_string()), 356 398 metrics: None, 357 399 }; 358 400 359 - let chat = 360 - save_chat(&conn, &user, "hello", Some(&chat_resp)).expect("chat should be saved"); 401 + let chat = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 361 402 362 403 assert!(chat.context_id.is_some()); 363 404 let saved = fetch_saved_chat_row(&conn, &chat.id); ··· 369 410 fn test_empty_input_is_saved() { 370 411 let conn = setup_db_schema(); 371 412 let user = create_user(); 372 - 373 - let chat = save_chat(&conn, &user, "", None).expect("empty content should still be saved"); 413 + let chat_response = ChatResponse { 414 + input: "".to_owned(), 415 + session_id: String::from("session_abc"), 416 + role: Role::User, 417 + code: None, 418 + prev_response_id: None, 419 + parent_chat_id: None, 420 + metrics: None, 421 + }; 422 + let chat = 423 + save_chat(&conn, &user, chat_response).expect("empty content should still be saved"); 374 424 375 425 let saved = fetch_saved_chat_row(&conn, &chat.id); 376 426 assert_eq!(saved.content, ""); ··· 381 431 fn test_save_chat_errors_when_table_missing() { 382 432 let conn = Connection::open_in_memory().expect("in-memory db should open"); 383 433 let user = create_user(); 384 - 385 - let result = save_chat(&conn, &user, "2+2", None); 434 + let chat_response = ChatResponse { 435 + input: "".to_owned(), 436 + session_id: String::from("session_abc"), 437 + role: Role::User, 438 + code: None, 439 + prev_response_id: None, 440 + parent_chat_id: None, 441 + metrics: None, 442 + }; 443 + let result = save_chat(&conn, &user, chat_response); 386 444 387 445 assert!(result.is_err()); 388 446 } ··· 391 449 fn test_last_row_counter() { 392 450 let conn = setup_db_schema(); 393 451 let user = create_user(); 394 - let input = "2+2"; 395 - let chat = save_chat(&conn, &user, input, None).expect("chat should be saved"); 452 + let chat_response = ChatResponse { 453 + input: "".to_owned(), 454 + session_id: String::from("session_abc"), 455 + role: Role::User, 456 + code: None, 457 + prev_response_id: None, 458 + parent_chat_id: None, 459 + metrics: None, 460 + }; 461 + let chat = save_chat(&conn, &user, chat_response).expect("chat should be saved"); 396 462 397 463 assert_eq!(chat.user_id, user.user_id); 398 464 assert!(chat.response_id.is_none()); ··· 415 481 let conn = setup_db_schema(); 416 482 let user = create_user(); 417 483 let input = "2+2"; 418 - let chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved"); 419 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 420 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 421 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 484 + let chat_response = ChatResponse { 485 + input: input.to_owned(), 486 + session_id: String::from("session_abc"), 487 + role: Role::User, 488 + code: None, 489 + prev_response_id: None, 490 + parent_chat_id: None, 491 + metrics: None, 492 + }; 493 + let chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 494 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 495 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 496 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 422 497 423 498 let rows = get_delta(&conn, &user.user_id, chat_1.row_counter).unwrap(); 424 499 assert_eq!(rows.len(), 3); ··· 429 504 let conn = setup_db_schema(); 430 505 let user = create_user(); 431 506 let input = "2+2"; 432 - let _chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved"); 433 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 434 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 435 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 507 + let chat_response = ChatResponse { 508 + input: input.to_owned(), 509 + session_id: String::from("session_abc"), 510 + role: Role::User, 511 + code: None, 512 + prev_response_id: None, 513 + parent_chat_id: None, 514 + metrics: None, 515 + }; 516 + let _chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 517 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 518 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 519 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 436 520 437 521 let rows = get_delta(&conn, &user.user_id, 0).unwrap(); 438 522 assert_eq!(rows.len(), 4); ··· 443 527 let conn = setup_db_schema(); 444 528 let user = create_user(); 445 529 let input = "2+2"; 446 - let _chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved"); 447 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 448 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 449 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 530 + let chat_response = ChatResponse { 531 + input: input.to_owned(), 532 + session_id: String::from("session_abc"), 533 + role: Role::User, 534 + code: None, 535 + prev_response_id: None, 536 + parent_chat_id: None, 537 + metrics: None, 538 + }; 539 + let _chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 540 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 541 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 542 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 450 543 451 544 let rows = get_delta(&conn, "", 0).unwrap(); 452 545 assert_eq!(rows.len(), 0); ··· 458 551 let mut conn_2 = setup_db_schema(); 459 552 let user = create_user(); 460 553 let input = "2+2"; 461 - let _chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved"); 462 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 463 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 464 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 554 + let chat_response = ChatResponse { 555 + input: input.to_owned(), 556 + session_id: String::from("session_abc"), 557 + role: Role::User, 558 + code: None, 559 + prev_response_id: None, 560 + parent_chat_id: None, 561 + metrics: None, 562 + }; 563 + let _chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 564 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 565 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 566 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 465 567 466 568 let rows = get_delta(&conn, &user.user_id, 0).unwrap(); 467 569 assert_eq!(rows.len(), 4); ··· 476 578 let mut conn_2 = setup_db_schema(); 477 579 let user = create_user(); 478 580 let input = "2+2"; 479 - let _chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved"); 480 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 481 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 482 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 581 + let chat_response = ChatResponse { 582 + input: input.to_owned(), 583 + session_id: String::from("session_abc"), 584 + role: Role::User, 585 + code: None, 586 + prev_response_id: None, 587 + parent_chat_id: None, 588 + metrics: None, 589 + }; 590 + let _chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 591 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 592 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 593 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 483 594 484 595 let rows = get_delta(&conn, &user.user_id, 0).unwrap(); 485 596 assert_eq!(rows.len(), 4); ··· 496 607 let mut conn_2 = setup_db_schema(); 497 608 let user = create_user(); 498 609 let input = "2+2"; 499 - let _chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved"); 500 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 501 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 502 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 610 + let chat_response = ChatResponse { 611 + input: input.to_owned(), 612 + session_id: String::from("session_abc"), 613 + role: Role::User, 614 + code: None, 615 + prev_response_id: None, 616 + parent_chat_id: None, 617 + metrics: None, 618 + }; 619 + let _chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 620 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 621 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 622 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 503 623 504 624 let rows = get_delta(&conn, &user.user_id, 4).unwrap(); 505 625 assert_eq!(rows.len(), 0); ··· 516 636 let mut _conn_2 = setup_db_schema(); 517 637 let user = create_user(); 518 638 let input = "2+2"; 519 - let chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved"); 520 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 521 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 522 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 639 + let chat_response = ChatResponse { 640 + input: input.to_owned(), 641 + session_id: String::from("session_abc"), 642 + role: Role::User, 643 + code: None, 644 + prev_response_id: None, 645 + parent_chat_id: None, 646 + metrics: None, 647 + }; 648 + let chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 649 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 650 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 651 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 523 652 let rows = get_delta(&conn, &user.user_id, chat_1.row_counter).unwrap(); 524 653 assert_eq!(rows.len(), 3); 525 654 } ··· 531 660 let mut conn_2 = setup_db_schema(); 532 661 let user = create_user(); 533 662 let input = "2+2"; 534 - let _chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved"); 535 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 536 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 537 - let _ = save_chat(&conn, &user, input, None).expect("chat should be saved"); 663 + let chat_response = ChatResponse { 664 + input: input.to_owned(), 665 + session_id: String::from("session_abc"), 666 + role: Role::User, 667 + code: None, 668 + prev_response_id: None, 669 + parent_chat_id: None, 670 + metrics: None, 671 + }; 672 + let _chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 673 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 674 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 675 + let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved"); 538 676 539 677 let rows = get_delta(&conn, &user.user_id, 0).unwrap(); 540 678 assert_eq!(rows.len(), 4); ··· 558 696 559 697 // Node user A adds stuff 560 698 let input = "2+2"; 561 - let _chat_1 = save_chat(&conn, &user_a, input, None).expect("chat should be saved"); 562 - let _ = save_chat(&conn, &user_a, input, None).expect("chat should be saved"); 563 - let _ = save_chat(&conn, &user_a, input, None).expect("chat should be saved"); 564 - let _ = save_chat(&conn, &user_a, input, None).expect("chat should be saved"); 699 + let chat_response = ChatResponse { 700 + input: input.to_owned(), 701 + session_id: String::from("session_abc"), 702 + role: Role::User, 703 + code: None, 704 + prev_response_id: None, 705 + parent_chat_id: None, 706 + metrics: None, 707 + }; 708 + let _chat_1 = 709 + save_chat(&conn, &user_a, chat_response.clone()).expect("chat should be saved"); 710 + let _ = save_chat(&conn, &user_a, chat_response.clone()).expect("chat should be saved"); 711 + let _ = save_chat(&conn, &user_a, chat_response.clone()).expect("chat should be saved"); 712 + let _ = save_chat(&conn, &user_a, chat_response.clone()).expect("chat should be saved"); 565 713 566 714 // Node user B adds stuff 567 715 let input = "4+4"; 568 - let _chat_1 = save_chat(&conn_2, &user_b, input, None).expect("chat should be saved"); 569 - let _ = save_chat(&conn_2, &user_b, input, None).expect("chat should be saved"); 570 - let _ = save_chat(&conn_2, &user_b, input, None).expect("chat should be saved"); 571 - let _ = save_chat(&conn_2, &user_b, input, None).expect("chat should be saved"); 716 + let chat_response = ChatResponse { 717 + input: input.to_owned(), 718 + session_id: String::from("session_abc"), 719 + role: Role::User, 720 + code: None, 721 + prev_response_id: None, 722 + parent_chat_id: None, 723 + metrics: None, 724 + }; 725 + let _chat_1 = 726 + save_chat(&conn_2, &user_b, chat_response.clone()).expect("chat should be saved"); 727 + let _ = save_chat(&conn_2, &user_b, chat_response.clone()).expect("chat should be saved"); 728 + let _ = save_chat(&conn_2, &user_b, chat_response.clone()).expect("chat should be saved"); 729 + let _ = save_chat(&conn_2, &user_b, chat_response.clone()).expect("chat should be saved"); 572 730 573 731 // Node A wants to sync with Node B 574 732 ··· 638 796 assert_eq!(user_a_rows, user_b_rows); 639 797 } 640 798 799 + #[test] 800 + fn test_valid_input_create_session() { 801 + let conn = setup_db_schema(); 802 + let user = create_user(); 803 + 804 + let session = create_session(&conn, "id-1", "sesh-1", &user.user_id).unwrap(); 805 + assert_eq!(user.user_id, session.creator_id); 806 + } 807 + 808 + #[test] 809 + fn test_duplicate_id_create_session() { 810 + let conn = setup_db_schema(); 811 + let user = create_user(); 812 + 813 + let session = create_session(&conn, "id-1", "sesh-1", &user.user_id).unwrap(); 814 + assert_eq!(user.user_id, session.creator_id); 815 + 816 + let session_2 = create_session(&conn, "id-1", "sesh-1", &user.user_id).unwrap(); 817 + assert_eq!(user.user_id, session_2.creator_id); 818 + } 819 + 641 820 struct SavedChatRow { 642 821 content: String, 643 822 resp_id: Option<String>, ··· 718 897 ) 719 898 .unwrap(); 720 899 900 + conn.execute( 901 + "CREATE TABLE IF NOT EXISTS sessions ( 902 + id TEXT PRIMARY KEY, 903 + name TEXT NOT NULL, 904 + creator_id TEXT NOT NULL, 905 + created_at INTEGER NOT NULL 906 + )", 907 + [], 908 + ) 909 + .unwrap(); 721 910 conn 722 911 } 723 912 }
+8
tiles/src/core/storage/db.rs
··· 70 70 ALTER TABLE CHATS ADD COLUMN session_id TEXT; 71 71 ", 72 72 ), 73 + M::up( 74 + "CREATE TABLE IF NOT EXISTS sessions ( 75 + id TEXT PRIMARY KEY, 76 + name TEXT NOT NULL, 77 + creator_id TEXT NOT NULL, 78 + created_at INTEGER NOT NULL 79 + )", 80 + ), 73 81 ]; 74 82 75 83 const CHATS_MIGRATIONS: Migrations = Migrations::from_slice(CHATS_MIGRATION_ARRAY);
+243 -162
tiles/src/runtime/mlx.rs
··· 1 1 use crate::core::accounts::{User, get_current_user}; 2 - use crate::core::chats::{Message, save_chat}; 2 + use crate::core::chats::{Message, create_session, save_chat}; 3 3 use crate::core::storage::db::Dbconn; 4 4 use crate::runtime::RunArgs; 5 - use crate::utils::config::{ 6 - ConfigProvider, DefaultProvider, get_memory_path, get_model_cache, update_current_model, 7 - }; 5 + use crate::utils::config::{ConfigProvider, DefaultProvider, get_memory_path, get_model_cache}; 8 6 use crate::utils::hf_model_downloader::*; 9 7 use anyhow::{Context, Result, anyhow}; 10 - use futures_util::StreamExt; 11 - use owo_colors::OwoColorize; 8 + use log::info; 12 9 use reqwest::{Client, StatusCode}; 13 - use rusqlite::Connection; 14 10 use rustyline::completion::Completer; 15 11 use rustyline::highlight::Highlighter; 16 12 use rustyline::hint::Hinter; ··· 19 15 use rustyline::{Config, Editor, Helper}; 20 16 use serde::{Deserialize, Serialize}; 21 17 use serde_json::{Value, json}; 22 - use std::cell::{Cell, RefCell}; 23 - use std::collections::HashMap; 24 18 use std::fs::OpenOptions; 25 - use std::io::{BufRead, BufReader, Read, Write}; 19 + use std::io::{BufRead, BufReader, Write}; 26 20 use std::path::PathBuf; 27 - use std::process::{Child, ChildStdout, Command}; 21 + use std::process::{Child, Command}; 28 22 use std::process::{ChildStdin, Stdio}; 29 - use std::rc::Rc; 30 23 use std::time::Duration; 31 24 use tilekit::modelfile::Modelfile; 32 25 use tilekit::modelfile::Role; ··· 59 52 60 53 #[derive(Clone, Debug)] 61 54 pub struct ChatResponse { 62 - // think: String, 63 - pub reply: String, 64 - pub code: String, 65 - pub prev_response_id: String, 55 + // text content 56 + pub input: String, 57 + pub session_id: String, 58 + pub role: Role, 59 + pub code: Option<String>, 60 + // deprecated, will remove soon 61 + pub prev_response_id: Option<String>, 66 62 pub parent_chat_id: Option<String>, 67 63 pub metrics: Option<BenchmarkMetrics>, 68 64 } ··· 78 74 MessageUpdate(PiMessageUpdate), 79 75 #[serde(rename = "agent_end")] 80 76 AgentEnd, 77 + #[serde(rename = "turn_end")] 78 + TurnEnd(PiTurnEndEvent), 81 79 #[serde[other]] 82 80 Unknown, 83 81 } ··· 110 108 command: CommandType, 111 109 success: bool, 112 110 data: Option<Value>, 111 + } 112 + 113 + #[derive(Serialize, Deserialize, Debug)] 114 + struct PiTurnEndEvent { 115 + message: PiTurnEndEventMsg, 116 + } 117 + 118 + #[derive(Serialize, Deserialize, Debug)] 119 + struct PiTurnEndEventMsg { 120 + role: String, 121 + content: Vec<PiMsgContent>, 122 + } 123 + 124 + #[derive(Serialize, Deserialize, Debug)] 125 + struct PiMsgContent { 126 + r#type: String, 127 + text: String, 113 128 } 114 129 115 130 impl Default for MLXRuntime { ··· 324 339 let mut conversations: Vec<Message> = vec![]; 325 340 326 341 let mut pi_process = start_pi_rpc()?; 327 - 342 + let mut session_id = String::new(); 328 343 let pi_stdin = pi_process.stdin.as_mut().unwrap(); 329 344 let mut stdout = pi_process.stdout.take().expect("stdout"); 330 - // let mut stdout: Cell<ChildStdout> = Cell::new(); 345 + let inti_cmd_payload = get_command_payload(CommandType::State); 346 + send_to_pi(pi_stdin, inti_cmd_payload).inspect_err(|_e| eprintln!("send pi failed"))?; 347 + 348 + //TODO: Refactor session_id fetching 349 + let mut pi_session_state = String::new(); 350 + let mut reader = BufReader::new(&mut stdout); 351 + let _ = reader 352 + .read_line(&mut pi_session_state) 353 + .context("Failed reading pi session state")?; 354 + println!("{}", pi_session_state); 355 + let response: PiResponse = serde_json::from_str(&pi_session_state)?; 356 + if let PiResponse::Response(msg) = response { 357 + let state: GetStateData = 358 + serde_json::from_value(msg.data.expect("get state parsing failed"))?; 359 + session_id = state.session_id; 360 + info!("Current session: {}", session_id); 361 + } 362 + 331 363 loop { 332 364 let readline = editor.readline(">>> "); 333 365 let input = match readline { 334 366 Ok(line) => line.trim().to_string(), 335 367 Err(_) => { 368 + //TODO: Panic when entering another prompt after ctr-l C 369 + // called `Result::unwrap()` on an `Err` value: Os { code: 32, kind: BrokenPipe, message: "Broken pipe" } 370 + // 336 371 // User pressed Ctrl+C or Ctrl+D 337 372 let end_payload = json!({ 338 373 "type": "abort", ··· 403 438 }; 404 439 let mut is_agent_streaming: bool = false; 405 440 let reader = BufReader::new(&mut stdout); 406 - 441 + let mut session_turn_count = 0; 442 + let mut last_chat_id: String = "".to_owned(); 407 443 for line in reader.lines() { 408 444 //TODO: handle the unwrap 409 445 let line = line?; ··· 411 447 412 448 match response { 413 449 PiResponse::AgentStart => { 414 - // agent streaming started 415 - is_agent_streaming = true 450 + info!("\nAgent start\n"); 416 451 } 417 452 PiResponse::MessageUpdate(msg_update) => { 418 453 if msg_update.assistant_message_event.r#type == "text_delta" 419 454 && msg_update.assistant_message_event.delta.is_some() 420 455 { 456 + // TODO: Can we remove the unwrap 421 457 print!("{}", msg_update.assistant_message_event.delta.unwrap()); 422 458 // TODO: maybe can optimize check print! doc 423 459 use std::io::Write; ··· 425 461 } 426 462 } 427 463 PiResponse::AgentEnd => { 428 - // agent streaming stopeed 429 - is_agent_streaming = false; 464 + info!("\nAgent End\n"); 430 465 break; 431 466 } 467 + PiResponse::TurnEnd(turn_event) => { 468 + info!("\nTurn end\n"); 469 + session_turn_count += 1; 470 + 471 + // on agent end create a new session entry, only for the 472 + // first time 473 + if session_turn_count == 1 { 474 + info!("Created session {}", session_id); 475 + create_session(&db_conn.chat, &session_id, "dummy", &current_user.user_id)?; 476 + } 477 + let parent_chat_id = if session_turn_count == 1 { 478 + None 479 + } else { 480 + Some(last_chat_id.clone()) 481 + }; 482 + let chat_response = ChatResponse { 483 + input: input.clone(), 484 + session_id: session_id.clone(), 485 + role: Role::User, 486 + code: None, 487 + prev_response_id: None, 488 + parent_chat_id, 489 + metrics: None, 490 + }; 491 + let prompt_chat = save_chat(&db_conn.chat, &current_user, chat_response)?; 492 + last_chat_id = prompt_chat.id; 493 + if turn_event.message.role == "assistant" { 494 + let mut content = turn_event.message.content; 495 + if let Some(msg) = content.pop() { 496 + let chat_response = ChatResponse { 497 + input: msg.text.clone(), 498 + session_id: session_id.clone(), 499 + role: Role::Assistant, 500 + code: None, 501 + prev_response_id: None, 502 + parent_chat_id: Some(last_chat_id.clone()), 503 + metrics: None, 504 + }; 505 + let chat = save_chat(&db_conn.chat, &current_user, chat_response)?; 506 + last_chat_id = chat.id; 507 + } 508 + } else { 509 + info!("Not handling {} role now", turn_event.message.role); 510 + } 511 + } 432 512 PiResponse::Response(response_msg) => { 433 513 if response_msg.success { 434 514 match response_msg.command { 435 515 CommandType::Unknown => { 516 + println!("{}", line); 436 517 continue; 437 518 } 438 519 cmd => process_command(cmd, response_msg.data)?, ··· 586 667 } 587 668 588 669 //TODO: Have 2 separate chat functions for memory and non-memory 589 - #[allow(clippy::too_many_arguments)] 590 - async fn chat( 591 - input: &str, 592 - modelfile: &Modelfile, 593 - chat_start: bool, 594 - python_code: &str, 595 - g_reply: &str, 596 - run_args: &RunArgs, 597 - prev_response_id: &str, 598 - conn: &Connection, 599 - user: &User, 600 - conversations: &[Message], 601 - ) -> Result<ChatResponse> { 602 - let client = Client::new(); 603 - let modelname = modelfile 604 - .from 605 - .clone() 606 - .ok_or_else(|| anyhow!("Failed to get model name"))?; 607 - let prompt = modelfile.system.clone().unwrap_or("".to_owned()); 608 - let convo_input = create_chat_input(input, prompt.as_str(), conversations); 609 - let body = json!({ 610 - "model": modelname, 611 - "input": convo_input, 612 - "reasoning": {"effort": "medium"}, 613 - "chat_start": chat_start, 614 - "stream": true, 615 - "previous_response_id": prev_response_id, 616 - "python_code": python_code, 617 - "messages": [{"role": "assistant", "content": g_reply}, {"role": "user", "content": input}] 618 - }); 670 + // #[allow(clippy::too_many_arguments)] 671 + // async fn chat( 672 + // input: &str, 673 + // modelfile: &Modelfile, 674 + // chat_start: bool, 675 + // python_code: &str, 676 + // g_reply: &str, 677 + // run_args: &RunArgs, 678 + // prev_response_id: &str, 679 + // conn: &Connection, 680 + // user: &User, 681 + // conversations: &[Message], 682 + // ) -> Result<ChatResponse> { 683 + // let client = Client::new(); 684 + // let modelname = modelfile 685 + // .from 686 + // .clone() 687 + // .ok_or_else(|| anyhow!("Failed to get model name"))?; 688 + // let prompt = modelfile.system.clone().unwrap_or("".to_owned()); 689 + // let convo_input = create_chat_input(input, prompt.as_str(), conversations); 690 + // let body = json!({ 691 + // "model": modelname, 692 + // "input": convo_input, 693 + // "reasoning": {"effort": "medium"}, 694 + // "chat_start": chat_start, 695 + // "stream": true, 696 + // "previous_response_id": prev_response_id, 697 + // "python_code": python_code, 698 + // "messages": [{"role": "assistant", "content": g_reply}, {"role": "user", "content": input}] 699 + // }); 619 700 620 - let memory_body = json!({ 621 - "model": modelname, 622 - "input": input, 623 - "chat_start": chat_start, 624 - "stream": true, 625 - "python_code": python_code, 626 - "messages": [{"role": "assistant", "content": g_reply}, {"role": "user", "content": input}] 701 + // let memory_body = json!({ 702 + // "model": modelname, 703 + // "input": input, 704 + // "chat_start": chat_start, 705 + // "stream": true, 706 + // "python_code": python_code, 707 + // "messages": [{"role": "assistant", "content": g_reply}, {"role": "user", "content": input}] 627 708 628 - }); 629 - let res = if run_args.memory { 630 - let api_url = "http://127.0.0.1:6969/v1/chat/completions"; 631 - client.post(api_url).json(&memory_body).send().await? 632 - } else { 633 - let api_url = "http://127.0.0.1:6969/v1/responses"; 634 - client.post(api_url).json(&body).send().await? 635 - }; 709 + // }); 710 + // let res = if run_args.memory { 711 + // let api_url = "http://127.0.0.1:6969/v1/chat/completions"; 712 + // client.post(api_url).json(&memory_body).send().await? 713 + // } else { 714 + // let api_url = "http://127.0.0.1:6969/v1/responses"; 715 + // client.post(api_url).json(&body).send().await? 716 + // }; 636 717 637 - let chat = save_chat(conn, user, input, None)?; 638 - let mut stream = res.bytes_stream(); 639 - let mut accumulated = String::new(); 640 - let mut metrics: Option<BenchmarkMetrics> = None; 641 - let mut is_answer_start = false; 642 - let mut prev_response_id: String = String::from(""); 643 - let mut output_completed: bool = false; 644 - while let Some(chunk) = stream.next().await { 645 - let chunk = chunk?; 646 - let s = String::from_utf8_lossy(&chunk); 647 - for line in s.lines() { 648 - if !line.starts_with("data: ") { 649 - continue; 650 - } 718 + // let chat = save_chat(conn, user, input, None)?; 719 + // let mut stream = res.bytes_stream(); 720 + // let mut accumulated = String::new(); 721 + // let mut metrics: Option<BenchmarkMetrics> = None; 722 + // let mut is_answer_start = false; 723 + // let mut prev_response_id: String = String::from(""); 724 + // let mut output_completed: bool = false; 725 + // while let Some(chunk) = stream.next().await { 726 + // let chunk = chunk?; 727 + // let s = String::from_utf8_lossy(&chunk); 728 + // for line in s.lines() { 729 + // if !line.starts_with("data: ") { 730 + // continue; 731 + // } 651 732 652 - let data = line.trim_start_matches("data: "); 733 + // let data = line.trim_start_matches("data: "); 653 734 654 - if data == "[DONE]" { 655 - let mut chat_resp = convert_to_chat_response( 656 - &accumulated, 657 - run_args.memory, 658 - prev_response_id, 659 - metrics, 660 - ); 661 - chat_resp.parent_chat_id = Some(chat.id); 662 - return Ok(chat_resp); 663 - } 735 + // if data == "[DONE]" { 736 + // let mut chat_resp = convert_to_chat_response( 737 + // &accumulated, 738 + // run_args.memory, 739 + // prev_response_id, 740 + // metrics, 741 + // ); 742 + // chat_resp.parent_chat_id = Some(chat.id); 743 + // return Ok(chat_resp); 744 + // } 664 745 665 - //TODO: This will break if we ask the model to give an essay and all 666 - let v: Value = serde_json::from_str(data).unwrap(); 667 - // Check for metrics in the response 668 - if let Some(metrics_obj) = v.get("metrics") { 669 - metrics = serde_json::from_value(metrics_obj.clone()).ok(); 670 - } 671 - let model_text: Option<&str> = if run_args.memory { 672 - v["choices"][0]["delta"]["content"].as_str() 673 - } else { 674 - prev_response_id = serde_json::to_string(&v["id"])? 675 - .trim_matches('\"') 676 - .to_owned(); 746 + // //TODO: This will break if we ask the model to give an essay and all 747 + // let v: Value = serde_json::from_str(data).unwrap(); 748 + // // Check for metrics in the response 749 + // if let Some(metrics_obj) = v.get("metrics") { 750 + // metrics = serde_json::from_value(metrics_obj.clone()).ok(); 751 + // } 752 + // let model_text: Option<&str> = if run_args.memory { 753 + // v["choices"][0]["delta"]["content"].as_str() 754 + // } else { 755 + // prev_response_id = serde_json::to_string(&v["id"])? 756 + // .trim_matches('\"') 757 + // .to_owned(); 677 758 678 - if serde_json::to_string(&v["status"])?.contains("completed") { 679 - output_completed = true; 680 - } 759 + // if serde_json::to_string(&v["status"])?.contains("completed") { 760 + // output_completed = true; 761 + // } 681 762 682 - v["output"][0]["content"][0]["text"].as_str() 683 - }; 763 + // v["output"][0]["content"][0]["text"].as_str() 764 + // }; 684 765 685 - if let Some(delta) = model_text { 686 - if !run_args.memory { 687 - if delta.contains("**[Answer]**") { 688 - is_answer_start = true 689 - } 690 - if !output_completed { 691 - accumulated.push_str(delta); 692 - if !is_answer_start { 693 - print!("{}", delta.dimmed()); 694 - } else { 695 - print!("{}", delta); 696 - }; 697 - } 698 - } else { 699 - accumulated.push_str(delta); 700 - } 701 - use std::io::Write; 702 - std::io::stdout().flush().ok(); 703 - } 704 - } 705 - } 766 + // if let Some(delta) = model_text { 767 + // if !run_args.memory { 768 + // if delta.contains("**[Answer]**") { 769 + // is_answer_start = true 770 + // } 771 + // if !output_completed { 772 + // accumulated.push_str(delta); 773 + // if !is_answer_start { 774 + // print!("{}", delta.dimmed()); 775 + // } else { 776 + // print!("{}", delta); 777 + // }; 778 + // } 779 + // } else { 780 + // accumulated.push_str(delta); 781 + // } 782 + // use std::io::Write; 783 + // std::io::stdout().flush().ok(); 784 + // } 785 + // } 786 + // } 706 787 707 - Err(anyhow!("Result failed")) 708 - } 788 + // Err(anyhow!("Result failed")) 789 + // } 709 790 710 - fn convert_to_chat_response( 711 - content: &str, 712 - memory_mode: bool, 713 - prev_response_id: String, 714 - metrics: Option<BenchmarkMetrics>, 715 - ) -> ChatResponse { 716 - ChatResponse { 717 - reply: extract_reply(content, memory_mode), 718 - code: extract_python(content), 719 - prev_response_id, 720 - metrics, 721 - parent_chat_id: None, 722 - } 723 - } 791 + // fn convert_to_chat_response( 792 + // content: &str, 793 + // memory_mode: bool, 794 + // prev_response_id: String, 795 + // metrics: Option<BenchmarkMetrics>, 796 + // ) -> ChatResponse { 797 + // ChatResponse { 798 + // reply: extract_reply(content, memory_mode), 799 + // code: None, 800 + // prev_response_id, 801 + // metrics, 802 + // parent_chat_id: None, 803 + // } 804 + // } 724 805 725 - fn extract_reply(content: &str, memory_mode: bool) -> String { 726 - if !memory_mode && content.contains("**[Answer]**") { 727 - let list_a = content.split("**[Answer]**").collect::<Vec<&str>>(); 728 - list_a[1].to_owned() 729 - } else if !memory_mode { 730 - content.to_owned() 731 - } else if content.contains("<reply>") && content.contains("</reply>") { 732 - let list_a = content.split("<reply>").collect::<Vec<&str>>(); 733 - let list_b = list_a[1].split("</reply>").collect::<Vec<&str>>(); 734 - list_b[0].to_owned() 735 - } else { 736 - "".to_owned() 737 - } 738 - } 806 + // fn extract_reply(content: &str, memory_mode: bool) -> String { 807 + // if !memory_mode && content.contains("**[Answer]**") { 808 + // let list_a = content.split("**[Answer]**").collect::<Vec<&str>>(); 809 + // list_a[1].to_owned() 810 + // } else if !memory_mode { 811 + // content.to_owned() 812 + // } else if content.contains("<reply>") && content.contains("</reply>") { 813 + // let list_a = content.split("<reply>").collect::<Vec<&str>>(); 814 + // let list_b = list_a[1].split("</reply>").collect::<Vec<&str>>(); 815 + // list_b[0].to_owned() 816 + // } else { 817 + // "".to_owned() 818 + // } 819 + // } 739 820 740 821 fn extract_python(content: &str) -> String { 741 822 if content.contains("<python>") && content.contains("</python>") { ··· 852 933 let pi_process = Command::new(pi_exec_path) 853 934 .arg("--mode") 854 935 .arg("rpc") 855 - .arg("--no-session") 936 + // .arg("--no-session") 856 937 .env("PI_CODING_AGENT_DIR", pi_agent_dir) 857 938 .env("PI_OFFLINE", "true") 858 939 .stdin(Stdio::piped())
+1 -1
tiles/src/utils/mod.rs
··· 8 8 SystemTime::now() 9 9 .duration_since(UNIX_EPOCH) 10 10 .expect("time went backwards") 11 - .as_secs() 11 + .as_millis() as u64 12 12 } 13 13 14 14 pub fn test_logger() {