···2626use tilekit::modelfile::Role;
2727use tokio::time::sleep;
28282929+const MAX_LOAD_MODEL_RETRIES: u8 = 3;
3030+2931#[derive(Debug, Deserialize, Serialize, Clone)]
3032pub struct BenchmarkMetrics {
3133 ttft_ms: f64,
···238240 }
239241 // loading the model from mem-agent via daemon server
240242 let memory_path = get_memory_path().context("Setting/Retrieving memory_path failed")?;
241241- match load_model(&modelfile, &default_modelfile, &memory_path).await {
243243+ match load_model(&modelfile, &default_modelfile, &memory_path, 0).await {
242244 Ok(_) => start_repl(mlx_runtime, &modelfile, run_args, db_conn).await?,
243245 Err(err) => return Err(anyhow::anyhow!(err)),
244246 }
···400402 modelfile: &Modelfile,
401403 default_modelfile: &Modelfile,
402404 memory_path: &str,
405405+ retries: u8,
403406) -> Result<()> {
407407+ if retries > MAX_LOAD_MODEL_RETRIES {
408408+ return Err(anyhow!(
409409+ "Model loading retried failed after {} times",
410410+ retries
411411+ ));
412412+ }
404413 let model_name = modelfile.from.clone().unwrap();
414414+ let model_cache_res = get_model_cache(&model_name);
405415406406- if let Ok(model_cache_path) = get_model_cache(&model_name) {
407407- load_model_in_py(modelfile, default_modelfile, memory_path, &model_cache_path).await
408408- } else {
416416+ if model_cache_res.is_err() {
409417 download_model(&model_name).await?;
410410- let model_cache_path = get_model_cache(&model_name)?;
411411- load_model_in_py(modelfile, default_modelfile, memory_path, &model_cache_path).await
418418+ return Box::pin(load_model(modelfile, default_modelfile, memory_path, 0)).await;
419419+ }
420420+421421+ // If loading fails it most probably a partial downloaded
422422+ // model present, so we try to resume the download
423423+ if load_model_in_py(
424424+ modelfile,
425425+ default_modelfile,
426426+ memory_path,
427427+ &model_cache_res.unwrap(),
428428+ )
429429+ .await
430430+ .is_err()
431431+ {
432432+ log::warn!("Load model failed, resuming the partial download");
433433+ download_model(&model_name).await?;
434434+ Box::pin(load_model(
435435+ modelfile,
436436+ default_modelfile,
437437+ memory_path,
438438+ retries + 1,
439439+ ))
440440+ .await
441441+ } else {
442442+ Ok(())
412443 }
413444}
414445···635666 model_cache_path: &PathBuf,
636667) -> Result<()> {
637668 let client = Client::new();
638638- let model_name = modelfile.from.clone().unwrap();
669669+ let model_name = modelfile
670670+ .from
671671+ .clone()
672672+ .expect("Failed to get `FROM` of modelfile");
639673 let body = json!({
640674 "model": model_name,
641675 "memory_path": memory_path,
+1-1
tiles/src/utils/config.rs
···236236 Ok(())
237237}
238238239239-// Get the apt path where the model lies
239239+/// Get the apt path where the model in the system
240240pub fn get_model_cache(model_name: &str) -> Result<PathBuf> {
241241 let hf_model_dir = if model_name.starts_with("mlx-community/") {
242242 let model_spec_parts = model_name.split("/").collect::<Vec<&str>>();