···1010use crate::runtime::mlx::ChatResponse;
1111use crate::utils::get_unix_time_now;
1212use anyhow::{Result, anyhow};
1313-use log::info;
1313+use log::{info, warn};
1414use rusqlite::types::FromSqlError;
1515use rusqlite::{Connection, params};
1616use tilekit::modelfile::Role;
···5151 created_at: u64,
5252 updated_at: u64,
5353 row_counter: i64,
5454+ session_id: String,
5555+}
5656+5757+#[derive(Debug, serde::Serialize, serde::Deserialize)]
5858+pub struct Session {
5959+ pub id: String,
6060+ name: String,
6161+ created_at: u64,
6262+ creator_id: String,
5463}
55645665type Responder<T> = oneshot::Sender<T>;
···7079 },
7180}
72817373-pub fn save_chat(
7474- conn: &Connection,
7575- user: &User,
7676- input: &str,
7777- chat_resp: Option<&ChatResponse>,
7878-) -> Result<Chats> {
8282+pub fn save_chat(conn: &Connection, user: &User, chat_resp: ChatResponse) -> Result<Chats> {
7983 let row_counter = get_last_row_counter(conn, &user.user_id)?;
8080- if let Some(chat_response) = chat_resp {
8181- let chat_resp_cloned = chat_response.clone();
8282-8383- let chat = Chats {
8484- id: Uuid::now_v7().to_string(),
8585- user_id: user.user_id.clone(),
8686- content: input.to_owned(),
8787- response_id: Some(chat_resp_cloned.prev_response_id),
8888- role: Role::Assistant,
8989- context_id: chat_resp_cloned.parent_chat_id,
9090- created_at: get_unix_time_now(),
9191- updated_at: get_unix_time_now(),
9292- row_counter: row_counter + 1,
9393- };
9494-9595- conn.execute("insert into chats(id, user_id, content, resp_id, role, context_id, created_at, updated_at, row_counter) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", (&chat.id.to_string(), &chat.user_id, &chat.content, &chat.response_id, Into::<String>::into(chat.role), &chat.context_id, &chat.created_at.to_string(), &chat.updated_at.to_string(), &chat.row_counter))?;
9696-9797- Ok(chat)
9898- } else {
9999- let chat = Chats {
100100- id: Uuid::now_v7().to_string(),
101101- user_id: user.user_id.clone(),
102102- content: input.to_owned(),
103103- response_id: None,
104104- role: Role::User,
105105- context_id: None,
106106- created_at: get_unix_time_now(),
107107- updated_at: get_unix_time_now(),
108108- row_counter: row_counter + 1,
109109- };
8484+ let chat = Chats {
8585+ id: Uuid::now_v7().to_string(),
8686+ user_id: user.user_id.clone(),
8787+ content: chat_resp.input,
8888+ response_id: None,
8989+ role: chat_resp.role,
9090+ context_id: chat_resp.parent_chat_id,
9191+ created_at: get_unix_time_now(),
9292+ updated_at: get_unix_time_now(),
9393+ row_counter: row_counter + 1,
9494+ session_id: chat_resp.session_id,
9595+ };
11096111111- conn.execute("insert into chats(id, user_id, content, resp_id, role, context_id, created_at, updated_at, row_counter) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", (&chat.id, &chat.user_id, &chat.content, &chat.response_id, Into::<String>::into(chat.role), &chat.context_id, &chat.created_at.to_string(), &chat.updated_at.to_string(), &chat.row_counter))?;
9797+ conn.execute("insert into chats(id, user_id, content, resp_id, role, context_id, created_at, updated_at, row_counter, session_id) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", (&chat.id.to_string(), &chat.user_id, &chat.content, &chat.response_id, Into::<String>::into(chat.role), &chat.context_id, &chat.created_at.to_string(), &chat.updated_at.to_string(), &chat.row_counter, &chat.session_id))?;
11298113113- Ok(chat)
114114- }
9999+ Ok(chat)
115100}
116101117102/// Returns the `id` of the last entry of the given user_id
···145130}
146131/// Return list of rows for the given `user_id` since `last_row_counter`
147132pub fn get_delta(conn: &Connection, user_id: &str, last_row_couter: i64) -> Result<Vec<Chats>> {
148148- let mut stmt = conn.prepare("select id, user_id, content, resp_id, role, context_id, created_at, updated_at , row_counter from chats where user_id = ?1 and row_counter > ?2 order by id")?;
133133+ let mut stmt = conn.prepare("select id, user_id, content, resp_id, role, context_id, created_at, updated_at , row_counter, session_id from chats where user_id = ?1 and row_counter > ?2 order by id")?;
149134150135 let chat_rows = stmt.query_map(params![user_id, last_row_couter], |row| {
151136 let id: String = row.get(0)?;
···164149 created_at: created_at as u64,
165150 updated_at: updated_at as u64,
166151 row_counter: row.get(8)?,
152152+ session_id: row.get(9)?,
167153 })
168154 })?;
169155···184170185171 let txn = chat_conn.transaction()?;
186172 {
187187- let mut stmt = txn.prepare("insert into chats(id, user_id, content, resp_id, role, context_id, created_at, updated_at, row_counter) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)")?;
173173+ let mut stmt = txn.prepare("insert into chats(id, user_id, content, resp_id, role, context_id, created_at, updated_at, row_counter, session_id) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)")?;
188174189175 for chat in delta_chats {
190176 match stmt.execute(params![
···197183 &chat.created_at.to_string(),
198184 &chat.updated_at.to_string(),
199185 &chat.row_counter,
186186+ &chat.session_id
200187 ]) {
201188 Err(rusqlite::Error::SqliteFailure(_, Some(reason)))
202189 if reason == "UNIQUE constraint failed: chats.id" =>
···269256 tx
270257}
271258259259+pub fn create_session(conn: &Connection, id: &str, name: &str, user_id: &str) -> Result<Session> {
260260+ // log a warning if session already exists, and skip the conflict
261261+262262+ let mut stmt = conn.prepare(
263263+ "insert into sessions(id, name, creator_id, created_at) values (?1, ?2, ?3, ?4)",
264264+ )?;
265265+266266+ match stmt.execute(params![
267267+ id.to_owned(),
268268+ name.to_owned(),
269269+ user_id.to_owned(),
270270+ get_unix_time_now() as f64
271271+ ]) {
272272+ Ok(_res) => {
273273+ let sesh = fetch_session(conn, id)?;
274274+ Ok(sesh)
275275+ }
276276+ Err(rusqlite::Error::SqliteFailure(_, Some(reason)))
277277+ if reason == "UNIQUE constraint failed: sessions.id" =>
278278+ {
279279+ warn!("Session entry already exists, skipping");
280280+ let sesh = fetch_session(conn, id)?;
281281+ Ok(sesh)
282282+ }
283283+ Err(err) => Err(anyhow!("Err inserting due to {}", err)),
284284+ }
285285+}
286286+287287+fn fetch_session(conn: &Connection, session_id: &str) -> Result<Session> {
288288+ let sesh = conn.query_row(
289289+ "SELECT id, name, creator_id, created_at FROM sessions WHERE id = ?1",
290290+ [session_id],
291291+ |row| {
292292+ Ok(Session {
293293+ id: row.get(0)?,
294294+ name: row.get(1)?,
295295+ creator_id: row.get(2)?,
296296+ created_at: row.get::<usize, f64>(3)? as u64,
297297+ })
298298+ },
299299+ )?;
300300+ Ok(sesh)
301301+}
272302fn encode_delta_to_bytes(delta_chats: &Vec<Chats>) -> Vec<u8> {
273303 postcard::to_stdvec(delta_chats).expect("Failed to convert to bytes with postcard")
274304}
···290320 core::{
291321 accounts::{ACCOUNT, User},
292322 chats::{
293293- apply_delta, decode_delta_from_bytes, encode_delta_to_bytes, get_delta,
294294- get_last_row_counter, save_chat,
323323+ apply_delta, create_session, decode_delta_from_bytes, encode_delta_to_bytes,
324324+ get_delta, get_last_row_counter, save_chat,
295325 },
296326 },
297327 runtime::mlx::ChatResponse,
···303333 let conn = setup_db_schema();
304334 let user = create_user();
305335 let input = "2+2";
306306- let chat = save_chat(&conn, &user, input, None).expect("chat should be saved");
336336+337337+ let chat_response = ChatResponse {
338338+ input: input.to_owned(),
339339+ session_id: String::from("session_abc"),
340340+ role: Role::User,
341341+ code: None,
342342+ prev_response_id: None,
343343+ parent_chat_id: None,
344344+ metrics: None,
345345+ };
346346+ let chat = save_chat(&conn, &user, chat_response).expect("chat should be saved");
307347308348 assert_eq!(chat.user_id, user.user_id);
309349 assert!(chat.response_id.is_none());
···322362 let conn = setup_db_schema();
323363 let user = create_user();
324364 let parent_chat_id = Uuid::now_v7().to_string();
325325- let chat_resp = ChatResponse {
326326- reply: "reply".to_owned(),
327327- code: "code".to_owned(),
328328- prev_response_id: String::from("resp_prev"),
365365+ let input = "2+2";
366366+ let chat_response = ChatResponse {
367367+ input: input.to_owned(),
368368+ session_id: String::from("session_abc"),
369369+ role: Role::Assistant,
370370+ code: None,
371371+ prev_response_id: None,
329372 parent_chat_id: Some(parent_chat_id.clone()),
330373 metrics: None,
331374 };
332332- let input = "2+2";
333333- let chat = save_chat(&conn, &user, input, Some(&chat_resp)).expect("chat should be saved");
375375+ let chat = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
334376335377 assert_eq!(chat.user_id, user.user_id);
336336- assert_eq!(chat.response_id.as_deref(), Some("resp_prev"));
337378 assert_eq!(chat.context_id, Some(parent_chat_id.clone()));
338379339380 let saved = fetch_saved_chat_row(&conn, &chat.id);
340381 assert_eq!(saved.content, input);
341341- assert_eq!(saved.resp_id, Some(String::from("resp_prev")));
342382 assert_eq!(saved.role, Into::<String>::into(Role::Assistant));
343383 assert_eq!(saved.user_id, user.user_id);
344384 assert_eq!(saved.context_id, Some(parent_chat_id.clone()));
···348388 fn test_response_without_parent_chat_id_saves_nil_context() {
349389 let conn = setup_db_schema();
350390 let user = create_user();
351351- let chat_resp = ChatResponse {
352352- reply: "reply".to_owned(),
353353- code: "code".to_owned(),
354354- prev_response_id: String::from("resp_prev"),
391391+ let chat_response = ChatResponse {
392392+ input: "".to_owned(),
393393+ session_id: String::from("session_abc"),
394394+ role: Role::Assistant,
395395+ code: None,
396396+ prev_response_id: Some(Uuid::now_v7().to_string()),
355397 parent_chat_id: Some(Uuid::now_v7().to_string()),
356398 metrics: None,
357399 };
358400359359- let chat =
360360- save_chat(&conn, &user, "hello", Some(&chat_resp)).expect("chat should be saved");
401401+ let chat = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
361402362403 assert!(chat.context_id.is_some());
363404 let saved = fetch_saved_chat_row(&conn, &chat.id);
···369410 fn test_empty_input_is_saved() {
370411 let conn = setup_db_schema();
371412 let user = create_user();
372372-373373- let chat = save_chat(&conn, &user, "", None).expect("empty content should still be saved");
413413+ let chat_response = ChatResponse {
414414+ input: "".to_owned(),
415415+ session_id: String::from("session_abc"),
416416+ role: Role::User,
417417+ code: None,
418418+ prev_response_id: None,
419419+ parent_chat_id: None,
420420+ metrics: None,
421421+ };
422422+ let chat =
423423+ save_chat(&conn, &user, chat_response).expect("empty content should still be saved");
374424375425 let saved = fetch_saved_chat_row(&conn, &chat.id);
376426 assert_eq!(saved.content, "");
···381431 fn test_save_chat_errors_when_table_missing() {
382432 let conn = Connection::open_in_memory().expect("in-memory db should open");
383433 let user = create_user();
384384-385385- let result = save_chat(&conn, &user, "2+2", None);
434434+ let chat_response = ChatResponse {
435435+ input: "".to_owned(),
436436+ session_id: String::from("session_abc"),
437437+ role: Role::User,
438438+ code: None,
439439+ prev_response_id: None,
440440+ parent_chat_id: None,
441441+ metrics: None,
442442+ };
443443+ let result = save_chat(&conn, &user, chat_response);
386444387445 assert!(result.is_err());
388446 }
···391449 fn test_last_row_counter() {
392450 let conn = setup_db_schema();
393451 let user = create_user();
394394- let input = "2+2";
395395- let chat = save_chat(&conn, &user, input, None).expect("chat should be saved");
452452+ let chat_response = ChatResponse {
453453+ input: "".to_owned(),
454454+ session_id: String::from("session_abc"),
455455+ role: Role::User,
456456+ code: None,
457457+ prev_response_id: None,
458458+ parent_chat_id: None,
459459+ metrics: None,
460460+ };
461461+ let chat = save_chat(&conn, &user, chat_response).expect("chat should be saved");
396462397463 assert_eq!(chat.user_id, user.user_id);
398464 assert!(chat.response_id.is_none());
···415481 let conn = setup_db_schema();
416482 let user = create_user();
417483 let input = "2+2";
418418- let chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved");
419419- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
420420- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
421421- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
484484+ let chat_response = ChatResponse {
485485+ input: input.to_owned(),
486486+ session_id: String::from("session_abc"),
487487+ role: Role::User,
488488+ code: None,
489489+ prev_response_id: None,
490490+ parent_chat_id: None,
491491+ metrics: None,
492492+ };
493493+ let chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
494494+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
495495+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
496496+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
422497423498 let rows = get_delta(&conn, &user.user_id, chat_1.row_counter).unwrap();
424499 assert_eq!(rows.len(), 3);
···429504 let conn = setup_db_schema();
430505 let user = create_user();
431506 let input = "2+2";
432432- let _chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved");
433433- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
434434- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
435435- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
507507+ let chat_response = ChatResponse {
508508+ input: input.to_owned(),
509509+ session_id: String::from("session_abc"),
510510+ role: Role::User,
511511+ code: None,
512512+ prev_response_id: None,
513513+ parent_chat_id: None,
514514+ metrics: None,
515515+ };
516516+ let _chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
517517+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
518518+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
519519+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
436520437521 let rows = get_delta(&conn, &user.user_id, 0).unwrap();
438522 assert_eq!(rows.len(), 4);
···443527 let conn = setup_db_schema();
444528 let user = create_user();
445529 let input = "2+2";
446446- let _chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved");
447447- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
448448- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
449449- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
530530+ let chat_response = ChatResponse {
531531+ input: input.to_owned(),
532532+ session_id: String::from("session_abc"),
533533+ role: Role::User,
534534+ code: None,
535535+ prev_response_id: None,
536536+ parent_chat_id: None,
537537+ metrics: None,
538538+ };
539539+ let _chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
540540+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
541541+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
542542+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
450543451544 let rows = get_delta(&conn, "", 0).unwrap();
452545 assert_eq!(rows.len(), 0);
···458551 let mut conn_2 = setup_db_schema();
459552 let user = create_user();
460553 let input = "2+2";
461461- let _chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved");
462462- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
463463- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
464464- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
554554+ let chat_response = ChatResponse {
555555+ input: input.to_owned(),
556556+ session_id: String::from("session_abc"),
557557+ role: Role::User,
558558+ code: None,
559559+ prev_response_id: None,
560560+ parent_chat_id: None,
561561+ metrics: None,
562562+ };
563563+ let _chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
564564+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
565565+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
566566+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
465567466568 let rows = get_delta(&conn, &user.user_id, 0).unwrap();
467569 assert_eq!(rows.len(), 4);
···476578 let mut conn_2 = setup_db_schema();
477579 let user = create_user();
478580 let input = "2+2";
479479- let _chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved");
480480- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
481481- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
482482- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
581581+ let chat_response = ChatResponse {
582582+ input: input.to_owned(),
583583+ session_id: String::from("session_abc"),
584584+ role: Role::User,
585585+ code: None,
586586+ prev_response_id: None,
587587+ parent_chat_id: None,
588588+ metrics: None,
589589+ };
590590+ let _chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
591591+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
592592+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
593593+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
483594484595 let rows = get_delta(&conn, &user.user_id, 0).unwrap();
485596 assert_eq!(rows.len(), 4);
···496607 let mut conn_2 = setup_db_schema();
497608 let user = create_user();
498609 let input = "2+2";
499499- let _chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved");
500500- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
501501- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
502502- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
610610+ let chat_response = ChatResponse {
611611+ input: input.to_owned(),
612612+ session_id: String::from("session_abc"),
613613+ role: Role::User,
614614+ code: None,
615615+ prev_response_id: None,
616616+ parent_chat_id: None,
617617+ metrics: None,
618618+ };
619619+ let _chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
620620+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
621621+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
622622+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
503623504624 let rows = get_delta(&conn, &user.user_id, 4).unwrap();
505625 assert_eq!(rows.len(), 0);
···516636 let mut _conn_2 = setup_db_schema();
517637 let user = create_user();
518638 let input = "2+2";
519519- let chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved");
520520- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
521521- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
522522- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
639639+ let chat_response = ChatResponse {
640640+ input: input.to_owned(),
641641+ session_id: String::from("session_abc"),
642642+ role: Role::User,
643643+ code: None,
644644+ prev_response_id: None,
645645+ parent_chat_id: None,
646646+ metrics: None,
647647+ };
648648+ let chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
649649+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
650650+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
651651+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
523652 let rows = get_delta(&conn, &user.user_id, chat_1.row_counter).unwrap();
524653 assert_eq!(rows.len(), 3);
525654 }
···531660 let mut conn_2 = setup_db_schema();
532661 let user = create_user();
533662 let input = "2+2";
534534- let _chat_1 = save_chat(&conn, &user, input, None).expect("chat should be saved");
535535- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
536536- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
537537- let _ = save_chat(&conn, &user, input, None).expect("chat should be saved");
663663+ let chat_response = ChatResponse {
664664+ input: input.to_owned(),
665665+ session_id: String::from("session_abc"),
666666+ role: Role::User,
667667+ code: None,
668668+ prev_response_id: None,
669669+ parent_chat_id: None,
670670+ metrics: None,
671671+ };
672672+ let _chat_1 = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
673673+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
674674+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
675675+ let _ = save_chat(&conn, &user, chat_response.clone()).expect("chat should be saved");
538676539677 let rows = get_delta(&conn, &user.user_id, 0).unwrap();
540678 assert_eq!(rows.len(), 4);
···558696559697 // Node user A adds stuff
560698 let input = "2+2";
561561- let _chat_1 = save_chat(&conn, &user_a, input, None).expect("chat should be saved");
562562- let _ = save_chat(&conn, &user_a, input, None).expect("chat should be saved");
563563- let _ = save_chat(&conn, &user_a, input, None).expect("chat should be saved");
564564- let _ = save_chat(&conn, &user_a, input, None).expect("chat should be saved");
699699+ let chat_response = ChatResponse {
700700+ input: input.to_owned(),
701701+ session_id: String::from("session_abc"),
702702+ role: Role::User,
703703+ code: None,
704704+ prev_response_id: None,
705705+ parent_chat_id: None,
706706+ metrics: None,
707707+ };
708708+ let _chat_1 =
709709+ save_chat(&conn, &user_a, chat_response.clone()).expect("chat should be saved");
710710+ let _ = save_chat(&conn, &user_a, chat_response.clone()).expect("chat should be saved");
711711+ let _ = save_chat(&conn, &user_a, chat_response.clone()).expect("chat should be saved");
712712+ let _ = save_chat(&conn, &user_a, chat_response.clone()).expect("chat should be saved");
565713566714 // Node user B adds stuff
567715 let input = "4+4";
568568- let _chat_1 = save_chat(&conn_2, &user_b, input, None).expect("chat should be saved");
569569- let _ = save_chat(&conn_2, &user_b, input, None).expect("chat should be saved");
570570- let _ = save_chat(&conn_2, &user_b, input, None).expect("chat should be saved");
571571- let _ = save_chat(&conn_2, &user_b, input, None).expect("chat should be saved");
716716+ let chat_response = ChatResponse {
717717+ input: input.to_owned(),
718718+ session_id: String::from("session_abc"),
719719+ role: Role::User,
720720+ code: None,
721721+ prev_response_id: None,
722722+ parent_chat_id: None,
723723+ metrics: None,
724724+ };
725725+ let _chat_1 =
726726+ save_chat(&conn_2, &user_b, chat_response.clone()).expect("chat should be saved");
727727+ let _ = save_chat(&conn_2, &user_b, chat_response.clone()).expect("chat should be saved");
728728+ let _ = save_chat(&conn_2, &user_b, chat_response.clone()).expect("chat should be saved");
729729+ let _ = save_chat(&conn_2, &user_b, chat_response.clone()).expect("chat should be saved");
572730573731 // Node A wants to sync with Node B
574732···638796 assert_eq!(user_a_rows, user_b_rows);
639797 }
640798799799+ #[test]
800800+ fn test_valid_input_create_session() {
801801+ let conn = setup_db_schema();
802802+ let user = create_user();
803803+804804+ let session = create_session(&conn, "id-1", "sesh-1", &user.user_id).unwrap();
805805+ assert_eq!(user.user_id, session.creator_id);
806806+ }
807807+808808+ #[test]
809809+ fn test_duplicate_id_create_session() {
810810+ let conn = setup_db_schema();
811811+ let user = create_user();
812812+813813+ let session = create_session(&conn, "id-1", "sesh-1", &user.user_id).unwrap();
814814+ assert_eq!(user.user_id, session.creator_id);
815815+816816+ let session_2 = create_session(&conn, "id-1", "sesh-1", &user.user_id).unwrap();
817817+ assert_eq!(user.user_id, session_2.creator_id);
818818+ }
819819+641820 struct SavedChatRow {
642821 content: String,
643822 resp_id: Option<String>,
···718897 )
719898 .unwrap();
720899900900+ conn.execute(
901901+ "CREATE TABLE IF NOT EXISTS sessions (
902902+ id TEXT PRIMARY KEY,
903903+ name TEXT NOT NULL,
904904+ creator_id TEXT NOT NULL,
905905+ created_at INTEGER NOT NULL
906906+ )",
907907+ [],
908908+ )
909909+ .unwrap();
721910 conn
722911 }
723912}
+8
tiles/src/core/storage/db.rs
···7070 ALTER TABLE CHATS ADD COLUMN session_id TEXT;
7171 ",
7272 ),
7373+ M::up(
7474+ "CREATE TABLE IF NOT EXISTS sessions (
7575+ id TEXT PRIMARY KEY,
7676+ name TEXT NOT NULL,
7777+ creator_id TEXT NOT NULL,
7878+ created_at INTEGER NOT NULL
7979+ )",
8080+ ),
7381];
74827583const CHATS_MIGRATIONS: Migrations = Migrations::from_slice(CHATS_MIGRATION_ARRAY);
+243-162
tiles/src/runtime/mlx.rs
···11use crate::core::accounts::{User, get_current_user};
22-use crate::core::chats::{Message, save_chat};
22+use crate::core::chats::{Message, create_session, save_chat};
33use crate::core::storage::db::Dbconn;
44use crate::runtime::RunArgs;
55-use crate::utils::config::{
66- ConfigProvider, DefaultProvider, get_memory_path, get_model_cache, update_current_model,
77-};
55+use crate::utils::config::{ConfigProvider, DefaultProvider, get_memory_path, get_model_cache};
86use crate::utils::hf_model_downloader::*;
97use anyhow::{Context, Result, anyhow};
1010-use futures_util::StreamExt;
1111-use owo_colors::OwoColorize;
88+use log::info;
129use reqwest::{Client, StatusCode};
1313-use rusqlite::Connection;
1410use rustyline::completion::Completer;
1511use rustyline::highlight::Highlighter;
1612use rustyline::hint::Hinter;
···1915use rustyline::{Config, Editor, Helper};
2016use serde::{Deserialize, Serialize};
2117use serde_json::{Value, json};
2222-use std::cell::{Cell, RefCell};
2323-use std::collections::HashMap;
2418use std::fs::OpenOptions;
2525-use std::io::{BufRead, BufReader, Read, Write};
1919+use std::io::{BufRead, BufReader, Write};
2620use std::path::PathBuf;
2727-use std::process::{Child, ChildStdout, Command};
2121+use std::process::{Child, Command};
2822use std::process::{ChildStdin, Stdio};
2929-use std::rc::Rc;
3023use std::time::Duration;
3124use tilekit::modelfile::Modelfile;
3225use tilekit::modelfile::Role;
···59526053#[derive(Clone, Debug)]
6154pub struct ChatResponse {
6262- // think: String,
6363- pub reply: String,
6464- pub code: String,
6565- pub prev_response_id: String,
5555+ // text content
5656+ pub input: String,
5757+ pub session_id: String,
5858+ pub role: Role,
5959+ pub code: Option<String>,
6060+ // deprecated, will remove soon
6161+ pub prev_response_id: Option<String>,
6662 pub parent_chat_id: Option<String>,
6763 pub metrics: Option<BenchmarkMetrics>,
6864}
···7874 MessageUpdate(PiMessageUpdate),
7975 #[serde(rename = "agent_end")]
8076 AgentEnd,
7777+ #[serde(rename = "turn_end")]
7878+ TurnEnd(PiTurnEndEvent),
8179 #[serde[other]]
8280 Unknown,
8381}
···110108 command: CommandType,
111109 success: bool,
112110 data: Option<Value>,
111111+}
112112+113113+#[derive(Serialize, Deserialize, Debug)]
114114+struct PiTurnEndEvent {
115115+ message: PiTurnEndEventMsg,
116116+}
117117+118118+#[derive(Serialize, Deserialize, Debug)]
119119+struct PiTurnEndEventMsg {
120120+ role: String,
121121+ content: Vec<PiMsgContent>,
122122+}
123123+124124+#[derive(Serialize, Deserialize, Debug)]
125125+struct PiMsgContent {
126126+ r#type: String,
127127+ text: String,
113128}
114129115130impl Default for MLXRuntime {
···324339 let mut conversations: Vec<Message> = vec![];
325340326341 let mut pi_process = start_pi_rpc()?;
327327-342342+ let mut session_id = String::new();
328343 let pi_stdin = pi_process.stdin.as_mut().unwrap();
329344 let mut stdout = pi_process.stdout.take().expect("stdout");
330330- // let mut stdout: Cell<ChildStdout> = Cell::new();
345345+ let inti_cmd_payload = get_command_payload(CommandType::State);
346346+ send_to_pi(pi_stdin, inti_cmd_payload).inspect_err(|_e| eprintln!("send pi failed"))?;
347347+348348+ //TODO: Refactor session_id fetching
349349+ let mut pi_session_state = String::new();
350350+ let mut reader = BufReader::new(&mut stdout);
351351+ let _ = reader
352352+ .read_line(&mut pi_session_state)
353353+ .context("Failed reading pi session state")?;
354354+ println!("{}", pi_session_state);
355355+ let response: PiResponse = serde_json::from_str(&pi_session_state)?;
356356+ if let PiResponse::Response(msg) = response {
357357+ let state: GetStateData =
358358+ serde_json::from_value(msg.data.expect("get state parsing failed"))?;
359359+ session_id = state.session_id;
360360+ info!("Current session: {}", session_id);
361361+ }
362362+331363 loop {
332364 let readline = editor.readline(">>> ");
333365 let input = match readline {
334366 Ok(line) => line.trim().to_string(),
335367 Err(_) => {
368368+ //TODO: Panic when entering another prompt after ctr-l C
369369+ // called `Result::unwrap()` on an `Err` value: Os { code: 32, kind: BrokenPipe, message: "Broken pipe" }
370370+ //
336371 // User pressed Ctrl+C or Ctrl+D
337372 let end_payload = json!({
338373 "type": "abort",
···403438 };
404439 let mut is_agent_streaming: bool = false;
405440 let reader = BufReader::new(&mut stdout);
406406-441441+ let mut session_turn_count = 0;
442442+ let mut last_chat_id: String = "".to_owned();
407443 for line in reader.lines() {
408444 //TODO: handle the unwrap
409445 let line = line?;
···411447412448 match response {
413449 PiResponse::AgentStart => {
414414- // agent streaming started
415415- is_agent_streaming = true
450450+ info!("\nAgent start\n");
416451 }
417452 PiResponse::MessageUpdate(msg_update) => {
418453 if msg_update.assistant_message_event.r#type == "text_delta"
419454 && msg_update.assistant_message_event.delta.is_some()
420455 {
456456+ // TODO: Can we remove the unwrap
421457 print!("{}", msg_update.assistant_message_event.delta.unwrap());
422458 // TODO: maybe can optimize check print! doc
423459 use std::io::Write;
···425461 }
426462 }
427463 PiResponse::AgentEnd => {
428428- // agent streaming stopeed
429429- is_agent_streaming = false;
464464+ info!("\nAgent End\n");
430465 break;
431466 }
467467+ PiResponse::TurnEnd(turn_event) => {
468468+ info!("\nTurn end\n");
469469+ session_turn_count += 1;
470470+471471+ // on agent end create a new session entry, only for the
472472+ // first time
473473+ if session_turn_count == 1 {
474474+ info!("Created session {}", session_id);
475475+ create_session(&db_conn.chat, &session_id, "dummy", ¤t_user.user_id)?;
476476+ }
477477+ let parent_chat_id = if session_turn_count == 1 {
478478+ None
479479+ } else {
480480+ Some(last_chat_id.clone())
481481+ };
482482+ let chat_response = ChatResponse {
483483+ input: input.clone(),
484484+ session_id: session_id.clone(),
485485+ role: Role::User,
486486+ code: None,
487487+ prev_response_id: None,
488488+ parent_chat_id,
489489+ metrics: None,
490490+ };
491491+ let prompt_chat = save_chat(&db_conn.chat, ¤t_user, chat_response)?;
492492+ last_chat_id = prompt_chat.id;
493493+ if turn_event.message.role == "assistant" {
494494+ let mut content = turn_event.message.content;
495495+ if let Some(msg) = content.pop() {
496496+ let chat_response = ChatResponse {
497497+ input: msg.text.clone(),
498498+ session_id: session_id.clone(),
499499+ role: Role::Assistant,
500500+ code: None,
501501+ prev_response_id: None,
502502+ parent_chat_id: Some(last_chat_id.clone()),
503503+ metrics: None,
504504+ };
505505+ let chat = save_chat(&db_conn.chat, ¤t_user, chat_response)?;
506506+ last_chat_id = chat.id;
507507+ }
508508+ } else {
509509+ info!("Not handling {} role now", turn_event.message.role);
510510+ }
511511+ }
432512 PiResponse::Response(response_msg) => {
433513 if response_msg.success {
434514 match response_msg.command {
435515 CommandType::Unknown => {
516516+ println!("{}", line);
436517 continue;
437518 }
438519 cmd => process_command(cmd, response_msg.data)?,
···586667}
587668588669//TODO: Have 2 separate chat functions for memory and non-memory
589589-#[allow(clippy::too_many_arguments)]
590590-async fn chat(
591591- input: &str,
592592- modelfile: &Modelfile,
593593- chat_start: bool,
594594- python_code: &str,
595595- g_reply: &str,
596596- run_args: &RunArgs,
597597- prev_response_id: &str,
598598- conn: &Connection,
599599- user: &User,
600600- conversations: &[Message],
601601-) -> Result<ChatResponse> {
602602- let client = Client::new();
603603- let modelname = modelfile
604604- .from
605605- .clone()
606606- .ok_or_else(|| anyhow!("Failed to get model name"))?;
607607- let prompt = modelfile.system.clone().unwrap_or("".to_owned());
608608- let convo_input = create_chat_input(input, prompt.as_str(), conversations);
609609- let body = json!({
610610- "model": modelname,
611611- "input": convo_input,
612612- "reasoning": {"effort": "medium"},
613613- "chat_start": chat_start,
614614- "stream": true,
615615- "previous_response_id": prev_response_id,
616616- "python_code": python_code,
617617- "messages": [{"role": "assistant", "content": g_reply}, {"role": "user", "content": input}]
618618- });
670670+// #[allow(clippy::too_many_arguments)]
671671+// async fn chat(
672672+// input: &str,
673673+// modelfile: &Modelfile,
674674+// chat_start: bool,
675675+// python_code: &str,
676676+// g_reply: &str,
677677+// run_args: &RunArgs,
678678+// prev_response_id: &str,
679679+// conn: &Connection,
680680+// user: &User,
681681+// conversations: &[Message],
682682+// ) -> Result<ChatResponse> {
683683+// let client = Client::new();
684684+// let modelname = modelfile
685685+// .from
686686+// .clone()
687687+// .ok_or_else(|| anyhow!("Failed to get model name"))?;
688688+// let prompt = modelfile.system.clone().unwrap_or("".to_owned());
689689+// let convo_input = create_chat_input(input, prompt.as_str(), conversations);
690690+// let body = json!({
691691+// "model": modelname,
692692+// "input": convo_input,
693693+// "reasoning": {"effort": "medium"},
694694+// "chat_start": chat_start,
695695+// "stream": true,
696696+// "previous_response_id": prev_response_id,
697697+// "python_code": python_code,
698698+// "messages": [{"role": "assistant", "content": g_reply}, {"role": "user", "content": input}]
699699+// });
619700620620- let memory_body = json!({
621621- "model": modelname,
622622- "input": input,
623623- "chat_start": chat_start,
624624- "stream": true,
625625- "python_code": python_code,
626626- "messages": [{"role": "assistant", "content": g_reply}, {"role": "user", "content": input}]
701701+// let memory_body = json!({
702702+// "model": modelname,
703703+// "input": input,
704704+// "chat_start": chat_start,
705705+// "stream": true,
706706+// "python_code": python_code,
707707+// "messages": [{"role": "assistant", "content": g_reply}, {"role": "user", "content": input}]
627708628628- });
629629- let res = if run_args.memory {
630630- let api_url = "http://127.0.0.1:6969/v1/chat/completions";
631631- client.post(api_url).json(&memory_body).send().await?
632632- } else {
633633- let api_url = "http://127.0.0.1:6969/v1/responses";
634634- client.post(api_url).json(&body).send().await?
635635- };
709709+// });
710710+// let res = if run_args.memory {
711711+// let api_url = "http://127.0.0.1:6969/v1/chat/completions";
712712+// client.post(api_url).json(&memory_body).send().await?
713713+// } else {
714714+// let api_url = "http://127.0.0.1:6969/v1/responses";
715715+// client.post(api_url).json(&body).send().await?
716716+// };
636717637637- let chat = save_chat(conn, user, input, None)?;
638638- let mut stream = res.bytes_stream();
639639- let mut accumulated = String::new();
640640- let mut metrics: Option<BenchmarkMetrics> = None;
641641- let mut is_answer_start = false;
642642- let mut prev_response_id: String = String::from("");
643643- let mut output_completed: bool = false;
644644- while let Some(chunk) = stream.next().await {
645645- let chunk = chunk?;
646646- let s = String::from_utf8_lossy(&chunk);
647647- for line in s.lines() {
648648- if !line.starts_with("data: ") {
649649- continue;
650650- }
718718+// let chat = save_chat(conn, user, input, None)?;
719719+// let mut stream = res.bytes_stream();
720720+// let mut accumulated = String::new();
721721+// let mut metrics: Option<BenchmarkMetrics> = None;
722722+// let mut is_answer_start = false;
723723+// let mut prev_response_id: String = String::from("");
724724+// let mut output_completed: bool = false;
725725+// while let Some(chunk) = stream.next().await {
726726+// let chunk = chunk?;
727727+// let s = String::from_utf8_lossy(&chunk);
728728+// for line in s.lines() {
729729+// if !line.starts_with("data: ") {
730730+// continue;
731731+// }
651732652652- let data = line.trim_start_matches("data: ");
733733+// let data = line.trim_start_matches("data: ");
653734654654- if data == "[DONE]" {
655655- let mut chat_resp = convert_to_chat_response(
656656- &accumulated,
657657- run_args.memory,
658658- prev_response_id,
659659- metrics,
660660- );
661661- chat_resp.parent_chat_id = Some(chat.id);
662662- return Ok(chat_resp);
663663- }
735735+// if data == "[DONE]" {
736736+// let mut chat_resp = convert_to_chat_response(
737737+// &accumulated,
738738+// run_args.memory,
739739+// prev_response_id,
740740+// metrics,
741741+// );
742742+// chat_resp.parent_chat_id = Some(chat.id);
743743+// return Ok(chat_resp);
744744+// }
664745665665- //TODO: This will break if we ask the model to give an essay and all
666666- let v: Value = serde_json::from_str(data).unwrap();
667667- // Check for metrics in the response
668668- if let Some(metrics_obj) = v.get("metrics") {
669669- metrics = serde_json::from_value(metrics_obj.clone()).ok();
670670- }
671671- let model_text: Option<&str> = if run_args.memory {
672672- v["choices"][0]["delta"]["content"].as_str()
673673- } else {
674674- prev_response_id = serde_json::to_string(&v["id"])?
675675- .trim_matches('\"')
676676- .to_owned();
746746+// //TODO: This will break if we ask the model to give an essay and all
747747+// let v: Value = serde_json::from_str(data).unwrap();
748748+// // Check for metrics in the response
749749+// if let Some(metrics_obj) = v.get("metrics") {
750750+// metrics = serde_json::from_value(metrics_obj.clone()).ok();
751751+// }
752752+// let model_text: Option<&str> = if run_args.memory {
753753+// v["choices"][0]["delta"]["content"].as_str()
754754+// } else {
755755+// prev_response_id = serde_json::to_string(&v["id"])?
756756+// .trim_matches('\"')
757757+// .to_owned();
677758678678- if serde_json::to_string(&v["status"])?.contains("completed") {
679679- output_completed = true;
680680- }
759759+// if serde_json::to_string(&v["status"])?.contains("completed") {
760760+// output_completed = true;
761761+// }
681762682682- v["output"][0]["content"][0]["text"].as_str()
683683- };
763763+// v["output"][0]["content"][0]["text"].as_str()
764764+// };
684765685685- if let Some(delta) = model_text {
686686- if !run_args.memory {
687687- if delta.contains("**[Answer]**") {
688688- is_answer_start = true
689689- }
690690- if !output_completed {
691691- accumulated.push_str(delta);
692692- if !is_answer_start {
693693- print!("{}", delta.dimmed());
694694- } else {
695695- print!("{}", delta);
696696- };
697697- }
698698- } else {
699699- accumulated.push_str(delta);
700700- }
701701- use std::io::Write;
702702- std::io::stdout().flush().ok();
703703- }
704704- }
705705- }
766766+// if let Some(delta) = model_text {
767767+// if !run_args.memory {
768768+// if delta.contains("**[Answer]**") {
769769+// is_answer_start = true
770770+// }
771771+// if !output_completed {
772772+// accumulated.push_str(delta);
773773+// if !is_answer_start {
774774+// print!("{}", delta.dimmed());
775775+// } else {
776776+// print!("{}", delta);
777777+// };
778778+// }
779779+// } else {
780780+// accumulated.push_str(delta);
781781+// }
782782+// use std::io::Write;
783783+// std::io::stdout().flush().ok();
784784+// }
785785+// }
786786+// }
706787707707- Err(anyhow!("Result failed"))
708708-}
788788+// Err(anyhow!("Result failed"))
789789+// }
709790710710-fn convert_to_chat_response(
711711- content: &str,
712712- memory_mode: bool,
713713- prev_response_id: String,
714714- metrics: Option<BenchmarkMetrics>,
715715-) -> ChatResponse {
716716- ChatResponse {
717717- reply: extract_reply(content, memory_mode),
718718- code: extract_python(content),
719719- prev_response_id,
720720- metrics,
721721- parent_chat_id: None,
722722- }
723723-}
791791+// fn convert_to_chat_response(
792792+// content: &str,
793793+// memory_mode: bool,
794794+// prev_response_id: String,
795795+// metrics: Option<BenchmarkMetrics>,
796796+// ) -> ChatResponse {
797797+// ChatResponse {
798798+// reply: extract_reply(content, memory_mode),
799799+// code: None,
800800+// prev_response_id,
801801+// metrics,
802802+// parent_chat_id: None,
803803+// }
804804+// }
724805725725-fn extract_reply(content: &str, memory_mode: bool) -> String {
726726- if !memory_mode && content.contains("**[Answer]**") {
727727- let list_a = content.split("**[Answer]**").collect::<Vec<&str>>();
728728- list_a[1].to_owned()
729729- } else if !memory_mode {
730730- content.to_owned()
731731- } else if content.contains("<reply>") && content.contains("</reply>") {
732732- let list_a = content.split("<reply>").collect::<Vec<&str>>();
733733- let list_b = list_a[1].split("</reply>").collect::<Vec<&str>>();
734734- list_b[0].to_owned()
735735- } else {
736736- "".to_owned()
737737- }
738738-}
806806+// fn extract_reply(content: &str, memory_mode: bool) -> String {
807807+// if !memory_mode && content.contains("**[Answer]**") {
808808+// let list_a = content.split("**[Answer]**").collect::<Vec<&str>>();
809809+// list_a[1].to_owned()
810810+// } else if !memory_mode {
811811+// content.to_owned()
812812+// } else if content.contains("<reply>") && content.contains("</reply>") {
813813+// let list_a = content.split("<reply>").collect::<Vec<&str>>();
814814+// let list_b = list_a[1].split("</reply>").collect::<Vec<&str>>();
815815+// list_b[0].to_owned()
816816+// } else {
817817+// "".to_owned()
818818+// }
819819+// }
739820740821fn extract_python(content: &str) -> String {
741822 if content.contains("<python>") && content.contains("</python>") {
···852933 let pi_process = Command::new(pi_exec_path)
853934 .arg("--mode")
854935 .arg("rpc")
855855- .arg("--no-session")
936936+ // .arg("--no-session")
856937 .env("PI_CODING_AGENT_DIR", pi_agent_dir)
857938 .env("PI_OFFLINE", "true")
858939 .stdin(Stdio::piped())
+1-1
tiles/src/utils/mod.rs
···88 SystemTime::now()
99 .duration_since(UNIX_EPOCH)
1010 .expect("time went backwards")
1111- .as_secs()
1111+ .as_millis() as u64
1212}
13131414pub fn test_logger() {