Skip to content

Commit

Permalink
fixed some bugs and better error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasLoos committed Mar 3, 2024
1 parent 59ae19d commit 49fffce
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 41 deletions.
18 changes: 8 additions & 10 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ <h1 style="text-align: center;">H-Space similarity explorer</h1>
worker.jobs = {};
worker.job_counter = 0;
worker.last_calc_promise = new Promise((resolve, reject) => {resolve()});
worker.last_fetch_repr_promise = new Promise((resolve, reject) => {resolve()});
worker.onmessage = (e) => {
const { id, data } = e.data;
const job = worker.jobs[id];
Expand All @@ -144,6 +143,9 @@ <h1 style="text-align: center;">H-Space similarity explorer</h1>
} else {
console.warn('Received message for unknown job:', e.data);
}
if (data.status === 'error') {
console.error('Error in worker:', data.msg);
}
};
worker.onerror = (e) => {
console.error('Error in worker:', e);
Expand All @@ -170,13 +172,9 @@ <h1 style="text-align: center;">H-Space similarity explorer</h1>
});
return worker.last_calc_promise;
} else if (task === 'fetch_repr') {
worker.last_fetch_repr_promise = worker.last_fetch_repr_promise.then(() => {
console.log('Fetching repr:', id, task, data.url);
worker.jobs[id] = job;
worker.postMessage({ id, task, data });
return promise;
});
return worker.last_fetch_repr_promise;
worker.jobs[id] = job;
worker.postMessage({ id, task, data });
return promise;
} else {
console.error('Unknown task:', task);
}
Expand Down Expand Up @@ -417,7 +415,7 @@ <h1 style="text-align: center;">H-Space similarity explorer</h1>
try {
img_ctx.drawImage(concept.img, 0, 0, img_ctx.canvas.width, img_ctx.canvas.height);
} catch (error) {
console.warn('Error while drawing image');
console.warn('Error while drawing image: ', error.toString());
}
drawGrid(img_ctx);
if (concept === base_concept) {
Expand Down Expand Up @@ -448,7 +446,7 @@ <h1 style="text-align: center;">H-Space similarity explorer</h1>
})
.catch((error) => {
// console.error('Error while calculating similarities:', error);
drawError(concept, 'Error calculating similarities...\nTry to change settings.');
drawError(concept, 'Error calculating similarities...\nMaybe try to change settings.');
updateCanvasesSoon(base_concept, col, row); // schedule update to check if representations are loaded and animate loading text
});
}
Expand Down
2 changes: 1 addition & 1 deletion worker/pkg/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ function __wbg_get_imports() {
return addHeapObject(ret);
};
imports.wbg.__wbindgen_closure_wrapper274 = function(arg0, arg1, arg2) {
const ret = makeMutClosure(arg0, arg1, 46, __wbg_adapter_20);
const ret = makeMutClosure(arg0, arg1, 45, __wbg_adapter_20);
return addHeapObject(ret);
};

Expand Down
Binary file modified worker/pkg/worker_bg.wasm
Binary file not shown.
54 changes: 28 additions & 26 deletions worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ use ndarray::{s, Array2, Array3, ArrayView1, Axis, Zip};
use std::sync::Arc;


macro_rules! jserr {() => {|e| JsValue::from_str(&format!("WASM: {:#?}", e))};}
macro_rules! jsnone {() => {JsValue::from_str(&format!("WASM: unexpected None in {}:{}:{}", file!(), line!(), column!()))};}


// cache to store representations and means
thread_local! {
static GLOBAL_MAP: RefCell<HashMap<String, Arc<(Array3<f32>,Array2<f32>,Array2<f32>)>>> = RefCell::new(HashMap::new());
Expand Down Expand Up @@ -38,8 +42,8 @@ pub fn calc_similarities(
// get representations from cache
GLOBAL_MAP.with(|map| {
let reprs = map.borrow();
let arc_data1 = reprs.get(&repr1_str).ok_or_else(|| JsValue::from_str(&format!("Failed to get representation, url: {}", repr1_str)))?;
let arc_data2 = reprs.get(&repr2_str).ok_or_else(|| JsValue::from_str(&format!("Failed to get representation, url: {}", repr2_str)))?;
let arc_data1 = reprs.get(&repr1_str).ok_or_else(|| JsValue::from_str("loading"))?;
let arc_data2 = reprs.get(&repr2_str).ok_or_else(|| JsValue::from_str("loading"))?;
let (repr1, repr2, means1_full, means2_full, norms1, norms2) = (&arc_data1.0, &arc_data2.0, &arc_data1.1, &arc_data2.1, &arc_data1.2, &arc_data2.2);
let n = (repr1.shape()[1] as f32).sqrt() as usize;

Expand Down Expand Up @@ -91,7 +95,7 @@ pub fn calc_similarities(

// normalize distances
if func == "euclidean" || func == "manhattan" || func == "chebyshev" {
let max_distance = *similarities.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
let max_distance = *similarities.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).ok_or(jsnone!())?;
for distance in similarities.iter_mut() {
*distance = 1.0 - (*distance / max_distance);
}
Expand Down Expand Up @@ -127,36 +131,34 @@ pub async fn fetch_repr(url: String, steps: usize, n: usize, m: usize) -> Result

// fetch representation
let global = js_sys::global().unchecked_into::<web_sys::WorkerGlobalScope>();
let resp_value = match JsFuture::from(global.fetch_with_request(&request)).await {
Ok(value) => value,
Err(e) => {
console::warn_1(&JsValue::from_str(&format!("WASM: Fetch error: {:?}", e)));
return Err(JsValue::from_str(&format!("Failed to fetch representation: {}", url)))
}
};
let resp: Response = match resp_value.dyn_into() {
Ok(resp) => resp,
Err(_) => {
console::warn_1(&JsValue::from_str("WASM: Failed to get Response: resp_value.dyn_into() failed"));
return Err(JsValue::from_str("Failed to get Response: resp_value.dyn_into() failed"));
}
let resp: Response = match JsFuture::from(global.fetch_with_request(&request)).await {
Ok(value) => value.dyn_into().map_err(jserr!())?,
Err(e) => return Err(JsValue::from_str(&format!("Failed to fetch representation ({}): {:#?}", url, e)))
};

// convert response to float16 vector
let buffer_value = JsFuture::from(resp.array_buffer()?).await?;
let buffer: ArrayBuffer = match buffer_value.dyn_into() {
Ok(buffer) => buffer,
Err(_) => return Err(JsValue::from_str("Failed to get ArrayBuffer: buffer_value.dyn_into() failed")),
};
let buffer: ArrayBuffer = JsFuture::from(resp.array_buffer()?).await?.dyn_into().map_err(jserr!())?;
if buffer.byte_length() % 2 != 0 {
return Err(JsValue::from_str("Buffer length is not a multiple of 2 (for float16)"));
return Err(JsValue::from_str(&format!("Buffer length is not a multiple of 2 (for float16): {}", url)));
}
let bytes = Uint8Array::new(&buffer).to_vec();
let float16_data: Vec<f16> = bytes.chunks(2).map(|chunk| f16::from_le_bytes([chunk[0], chunk[1]])).collect();
let float16_data: Vec<f16> = bytes.chunks_exact(2).map(|chunk| f16::from_le_bytes([chunk[0], chunk[1]])).collect();

// convert float16 vector to Array4<f32>
let representations = match Array3::from_shape_vec((steps, n*n, m), float16_data.iter().map(|&x| f32::from(x)).collect()) {
Ok(repr) => repr,
Err(e) => {
let new_steps = float16_data.len() / (n*n*m);
if new_steps * n*n*m != float16_data.len() {
return Err(JsValue::from_str(format!("Failed to convert float16 vector (len {}, {}) to Array3<f32> with shape ({}, {}, {}): {:#?}", float16_data.len(), url, steps, n*n, m, e).as_str()))
}
console::warn_1(&JsValue::from_str(format!("Failed to convert float16 vector (len {}, {}) to Array3<f32> with shape ({}, {}, {}), using shape ({}, {}, {}) instead", float16_data.len(), url, steps, n*n, m, new_steps, n*n, m).as_str()));
Array3::from_shape_vec((new_steps, n*n, m), float16_data.iter().map(|&x| f32::from(x)).collect()).map_err(jserr!())?
}
};

// convert float16 vector to Array4<f32> and store it in cache
let representations = Array3::from_shape_vec((steps, n*n, m), float16_data.iter().map(|&x| f32::from(x)).collect()).unwrap();
let means = representations.mean_axis(Axis(1)).unwrap();
// store representations and means and norms in cache
let means = representations.mean_axis(Axis(1)).ok_or(jsnone!())?;
let norms = representations.mapv(|x| x.powi(2)).sum_axis(Axis(2)).mapv(f32::sqrt);
GLOBAL_MAP.with(|map| {
map.borrow_mut().insert(url.to_string(), Arc::new((representations, means, norms)));
Expand Down
10 changes: 6 additions & 4 deletions worker/webworker.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,18 @@ onmessage = function(e) {
wasm.fetch_repr(url, steps, n, m)
.then(() => postMessage({id, data: {status: 'success'}}))
.catch((e) => postMessage({id, data: {status: 'error', msg: e}}))
.finally(() => console.log('finished fetching repr'));
} else if (e.data.task === 'calc_similarities') {
const { func, repr1_str, repr2_str, step1, step2, row, col } = e.data.data;
try {
const similarities = wasm.calc_similarities(func, repr1_str, repr2_str, step1, step2, row, col);
postMessage({ id, data: similarities });
} catch (e) {
// assume that the error is due to the loading of the representations
console.log('error in calc_similarities, assuming representaitons are loading\n', e);
postMessage({ id, data: 'loading' });
if (e === 'loading') {
postMessage({ id, data: 'loading' });
} else {
console.error('error in calc_similarities\n', e);
postMessage({ id, data: 'error' });
}
}
}
})
Expand Down

0 comments on commit 49fffce

Please sign in to comment.