Skip to content

Commit

Permalink
Propagate io errors when loading models, panic later
Browse files Browse the repository at this point in the history
Propagate the errors so they can be handled later in the CLI or in
whatever downstream app uses the crate.
  • Loading branch information
ZJaume committed Aug 28, 2024
1 parent 56a3b36 commit 0542590
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 36 deletions.
5 changes: 3 additions & 2 deletions src/identifier.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::BTreeMap;
use std::sync::Arc;
use std::io;

use ordered_float::OrderedFloat;
use strum::{IntoEnumIterator};
Expand Down Expand Up @@ -27,8 +28,8 @@ impl Identifier {
const PENALTY_VALUE : f32 = 7.0;
const MAX_NGRAM : usize = 6;

pub fn load(modelpath: &str) -> Self {
Self::new(Arc::new(Models::load(modelpath)))
pub fn load(modelpath: &str) -> io::Result<Self> {
Ok(Self::new(Arc::new(Models::load(modelpath)?)))
}

pub fn new(models: Arc<Models>) -> Self {
Expand Down
56 changes: 28 additions & 28 deletions src/languagemodel.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use std::collections::{HashMap, HashSet};
use std::hash::BuildHasherDefault;
use std::io::{Write, Read};
use std::io::{self, Write, Read};
use std::fs::{self, File};
use std::process::exit;
use std::path::Path;
use std::ops::Index;
use std::thread;

use strum::{IntoEnumIterator, Display, EnumCount};
use strum_macros::EnumIter;
use log::{debug, warn, error};
use log::{debug, warn};
use bitcode;

use wyhash2::WyHash;
Expand Down Expand Up @@ -128,25 +127,24 @@ impl Model {
}

// Create a new struct reading from a binary file
pub fn from_bin(p: &str) -> Self {
let mut file = File::open(p).expect(
format!("Error cannot open file {p:?}").as_str());
pub fn from_bin(p: &str) -> io::Result<Self> {
let mut file = File::open(p)?;
let mut content = Vec::new();
let _ = file.read_to_end(&mut content).unwrap();
let _ = file.read_to_end(&mut content)?;

bitcode::decode(&content).unwrap()
// TODO: find a way to propagate bitcode decode errors instead of panicking on unwrap
Ok(bitcode::decode(&content).unwrap())
}

// Save the struct in binary format
// take ownership of the struct
pub fn save(self, p: &Path) {
pub fn save(self, p: &Path) -> io::Result<()> {
// Create file
let mut file = File::create(p).expect(
format!("Error cannot write to file {p:?}").as_str());
let mut file = File::create(p)?;

let serialized = bitcode::encode(&self);
// Write serialized bytes to the file
file.write_all(&serialized).expect("Error writing serialized model");
file.write_all(&serialized)
}
}

Expand All @@ -155,40 +153,42 @@ pub struct Models {
}

impl Models {
pub fn load(modelpath: &str) -> Self {
pub fn load(modelpath: &str) -> io::Result<Self> {
// Run a separate thread to load each model
// check model type is correct
let mut handles: Vec<thread::JoinHandle<_>> = Vec::new();
for model_type in ModelType::iter() {
let type_repr = model_type.to_string();
let filename = format!("{modelpath}/{type_repr}.bin");

// If a model does not exist, fail early
let path = Path::new(&filename);
if !path.exists() {
error!("Model file '{}' could not be found", filename);
let message = format!("Model file '{}' could not be found", filename);
for h in handles {
let _ = h.join();
let _ = h.join().unwrap()?;
}
exit(1);
return Err(io::Error::new(io::ErrorKind::NotFound, message));
}
handles.push(thread::spawn(move || {
let model = Model::from_bin(&filename);
let model = Model::from_bin(&filename)?;
// check model type is correct
assert!(model.model_type == model_type);
model
Ok::<Model, io::Error>(model)
}));
}

Self {
Ok(Self {
// remove first position because after removal, the vec is reindexed
inner: [
handles.remove(0).join().unwrap(),
handles.remove(0).join().unwrap(),
handles.remove(0).join().unwrap(),
handles.remove(0).join().unwrap(),
handles.remove(0).join().unwrap(),
handles.remove(0).join().unwrap(),
handles.remove(0).join().unwrap(),
handles.remove(0).join().unwrap()?,
handles.remove(0).join().unwrap()?,
handles.remove(0).join().unwrap()?,
handles.remove(0).join().unwrap()?,
handles.remove(0).join().unwrap()?,
handles.remove(0).join().unwrap()?,
handles.remove(0).join().unwrap()?,
]
}
})
}
}

Expand Down
13 changes: 7 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use target;

use crate::languagemodel::{Model, ModelType};
use crate::identifier::Identifier;
use crate::utils::Abort;


pub mod languagemodel;
Expand Down Expand Up @@ -42,13 +43,13 @@ pub struct PyIdentifier {
#[pymethods]
impl PyIdentifier {
#[new]
fn new() -> Self {
fn new() -> PyResult<Self> {
let modulepath = module_path().expect("Error loading python module path");
let identifier = Identifier::load(&modulepath);
let identifier = Identifier::load(&modulepath)?;

Self {
Ok(Self {
inner: identifier,
}
})
}

fn identify(&mut self, text: &str) -> String {
Expand All @@ -66,7 +67,7 @@ impl PyIdentifier {
pub fn cli_run() -> PyResult<()> {
env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
let modulepath = module_path().expect("Error loading python module path");
let mut identifier = Identifier::load(&modulepath);
let mut identifier = Identifier::load(&modulepath).or_abort(1);

let stdin = io::stdin();

Expand Down Expand Up @@ -106,7 +107,7 @@ pub fn cli_compile() -> PyResult<()> {
let model = Model::from_text(modelpath, model_type);
let savepath = format!("{modulepath}/{type_repr}.bin");
info!("Saving {type_repr} model");
model.save(Path::new(&savepath));
model.save(Path::new(&savepath))?;
}
info!("Saved models at '{}'", modulepath);
info!("Finished");
Expand Down
17 changes: 17 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::process::{exit, Command};
use std::io;
use std::fs;

use log::{info, debug, error};
Expand Down Expand Up @@ -89,3 +90,19 @@ pub fn download_file_and_extract(url: &str, extractpath: &str) -> Result<(), Box
let runtime = Runtime::new()?;
runtime.block_on(download_file_and_extract_async(url, extractpath))
}

// Trait that extracts the contained ok value or aborts if error
// sending the error message to the log
pub trait Abort<T> {
fn or_abort(self, exit_code: i32) -> T;
}

impl<T> Abort<T> for io::Result<T>
{
fn or_abort(self, exit_code: i32) -> T {
match self {
Ok(v) => v,
Err(e) => { error!("{e}"); exit(exit_code); },
}
}
}

0 comments on commit 0542590

Please sign in to comment.