Skip to content

Commit

Permalink
split representations by steps, i.e. fix #1
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasLoos committed Mar 11, 2024
1 parent a550353 commit 1c06a72
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 51 deletions.
5 changes: 4 additions & 1 deletion generate-reprs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"metadata": {},
"outputs": [],
"source": [
"from diffusers import AutoPipelineForText2Image, AutoencoderKL\n",
"from diffusers import AutoPipelineForText2Image\n",
"import torch\n",
"import numpy as np\n",
"from pathlib import Path\n",
Expand Down Expand Up @@ -154,6 +154,9 @@
"\n",
" # save representations\n",
" for pos, reprs in representations.items():\n",
" for j, repr in enumerate(reprs, 0):\n",
" with open(save_path / f'repr-{pos}-{j}.bin', 'wb') as f:\n",
" f.write(np.array(repr, dtype=np.float16).tobytes())\n",
" with open(save_path / f'repr-{pos}.bin', 'wb') as f:\n",
" f.write(np.array(np.stack(reprs), dtype=np.float16).tobytes())\n",
"\n",
Expand Down
13 changes: 5 additions & 8 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,13 @@ <h1 style="text-align: center;">H-Space similarity explorer</h1>
concept.image_ctx = concept.image_canvas.getContext('2d');
concept.tile_ctx = concept.tile_canvas.getContext('2d');

concept.getUrl = () => `${window.location.pathname.split('/').slice(0, -1).join('/')}/representations/${current_model.short}/${concept.name}/repr-${current_position}.bin`;
concept.getUrl = () => `${window.location.pathname.split('/').slice(0, -1).join('/')}/representations/${current_model.short}/${concept.name}/repr-${current_position}-${concept.step-1}.bin`;

// helper function to load the representation
const getRepr = () => {
concept.repr_loading_started = Date.now();
const { steps, n, m } = current_model.getShapes();
callWasm('fetch_repr', { url: concept.getUrl(), steps, n, m })
callWasm('fetch_repr', { url: concept.getUrl(), n, m })
.then(() => {
console.log(`Fetched repr for ${concept.name} from ${concept.getUrl()}`);
updateCanvasesWithLastClicked();
Expand Down Expand Up @@ -310,11 +310,8 @@ <h1 style="text-align: center;">H-Space similarity explorer</h1>

// slider event listener
slider.addEventListener('input', event => {
const value = event.target.value;
slider_value.textContent = value;
concept.step = value;
concept.img.src = `representations/${current_model.short}/${concept.name}/${concept.step}.jpg`;
updateCanvasesWithLastClicked();
concept.step = event.target.value;
concept.update();
});

// concept name select event listener
Expand Down Expand Up @@ -432,7 +429,7 @@ <h1 style="text-align: center;">H-Space similarity explorer</h1>
img_ctx.fillRect(last_clicked.col * tile_size, last_clicked.row * tile_size, tile_size, tile_size);
}
// update tile canvas
callWasm('calc_similarities', { func, repr1_str: base_concept.getUrl(), repr2_str: concept.getUrl(), step1: base_concept.step-1, step2: concept.step-1, row, col })
callWasm('calc_similarities', { func, repr1_str: base_concept.getUrl(), repr2_str: concept.getUrl(), row, col })
.then((similarities) => {
if (similarities === 'loading') {
drawError(concept, 'Loading...');
Expand Down
20 changes: 7 additions & 13 deletions worker/pkg/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,11 @@ function getArrayF32FromWasm0(ptr, len) {
* @param {string} func
* @param {string} repr1_str
* @param {string} repr2_str
* @param {number} step1
* @param {number} step2
* @param {number} row
* @param {number} col
* @returns {Float32Array}
*/
export function calc_similarities(func, repr1_str, repr2_str, step1, step2, row, col) {
export function calc_similarities(func, repr1_str, repr2_str, row, col) {
try {
const retptr = wasm.__wbindgen_add_to_stack_pointer(-16);
const ptr0 = passStringToWasm0(func, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc);
Expand All @@ -243,7 +241,7 @@ export function calc_similarities(func, repr1_str, repr2_str, step1, step2, row,
const len1 = WASM_VECTOR_LEN;
const ptr2 = passStringToWasm0(repr2_str, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc);
const len2 = WASM_VECTOR_LEN;
wasm.calc_similarities(retptr, ptr0, len0, ptr1, len1, ptr2, len2, step1, step2, row, col);
wasm.calc_similarities(retptr, ptr0, len0, ptr1, len1, ptr2, len2, row, col);
var r0 = getInt32Memory0()[retptr / 4 + 0];
var r1 = getInt32Memory0()[retptr / 4 + 1];
var r2 = getInt32Memory0()[retptr / 4 + 2];
Expand All @@ -261,15 +259,14 @@ export function calc_similarities(func, repr1_str, repr2_str, step1, step2, row,

/**
* @param {string} url
* @param {number} steps
* @param {number} n
* @param {number} m
* @returns {Promise<void>}
*/
export function fetch_repr(url, steps, n, m) {
export function fetch_repr(url, n, m) {
const ptr0 = passStringToWasm0(url, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc);
const len0 = WASM_VECTOR_LEN;
const ret = wasm.fetch_repr(ptr0, len0, steps, n, m);
const ret = wasm.fetch_repr(ptr0, len0, n, m);
return takeObject(ret);
}

Expand All @@ -280,7 +277,7 @@ function handleError(f, args) {
wasm.__wbindgen_exn_store(addHeapObject(e));
}
}
function __wbg_adapter_57(arg0, arg1, arg2, arg3) {
function __wbg_adapter_55(arg0, arg1, arg2, arg3) {
wasm.wasm_bindgen__convert__closures__invoke2_mut__h493f10fe887c5e99(arg0, arg1, addHeapObject(arg2), addHeapObject(arg3));
}

Expand Down Expand Up @@ -381,9 +378,6 @@ function __wbg_get_imports() {
imports.wbg.__wbg_log_79d3c56888567995 = function(arg0) {
console.log(getObject(arg0));
};
imports.wbg.__wbg_warn_2a68e3ab54e55f28 = function(arg0) {
console.warn(getObject(arg0));
};
imports.wbg.__wbg_newwithstrandinit_11fbc38beb4c26b0 = function() { return handleError(function (arg0, arg1, arg2) {
const ret = new Request(getStringFromWasm0(arg0, arg1), getObject(arg2));
return addHeapObject(ret);
Expand Down Expand Up @@ -456,7 +450,7 @@ function __wbg_get_imports() {
const a = state0.a;
state0.a = 0;
try {
return __wbg_adapter_57(a, state0.b, arg0, arg1);
return __wbg_adapter_55(a, state0.b, arg0, arg1);
} finally {
state0.a = a;
}
Expand Down Expand Up @@ -512,7 +506,7 @@ function __wbg_get_imports() {
const ret = wasm.memory;
return addHeapObject(ret);
};
imports.wbg.__wbindgen_closure_wrapper274 = function(arg0, arg1, arg2) {
imports.wbg.__wbindgen_closure_wrapper266 = function(arg0, arg1, arg2) {
const ret = makeMutClosure(arg0, arg1, 45, __wbg_adapter_20);
return addHeapObject(ret);
};
Expand Down
Binary file modified worker/pkg/worker_bg.wasm
Binary file not shown.
43 changes: 18 additions & 25 deletions worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::cell::RefCell;
use js_sys::{ArrayBuffer, Uint8Array};
use console_error_panic_hook;
use std::panic;
use ndarray::{s, Array2, Array3, ArrayView1, Axis, Zip};
use ndarray::{s, Array1, Array2, ArrayView1, Axis, Zip};
use std::sync::Arc;


Expand All @@ -17,7 +17,7 @@ macro_rules! jsnone {() => {JsValue::from_str(&format!("WASM: unexpected None in

// cache to store representations and means
thread_local! {
static GLOBAL_MAP: RefCell<HashMap<String, Arc<(Array3<f32>,Array2<f32>,Array2<f32>)>>> = RefCell::new(HashMap::new());
static GLOBAL_MAP: RefCell<HashMap<String, Arc<(Array2<f32>,Array1<f32>,Array1<f32>)>>> = RefCell::new(HashMap::new());
}


Expand All @@ -27,8 +27,6 @@ pub fn calc_similarities(
func: String,
repr1_str: String,
repr2_str: String,
step1: usize,
step2: usize,
row: usize,
col: usize,
) -> Result<Vec<f32>, JsValue> {
Expand All @@ -45,45 +43,45 @@ pub fn calc_similarities(
let arc_data1 = reprs.get(&repr1_str).ok_or_else(|| JsValue::from_str("loading"))?;
let arc_data2 = reprs.get(&repr2_str).ok_or_else(|| JsValue::from_str("loading"))?;
let (repr1, repr2, means1_full, means2_full, norms1, norms2) = (&arc_data1.0, &arc_data2.0, &arc_data1.1, &arc_data2.1, &arc_data1.2, &arc_data2.2);
let n = (repr1.shape()[1] as f32).sqrt() as usize;
let n = (repr1.shape()[0] as f32).sqrt() as usize;

// calculate mean of both representations
let means = if func == "cosine_centered" {
Some(
Zip::from(&means1_full.slice(s![step1,..]))
.and(&means2_full.slice(s![step2,..]))
Zip::from(means1_full)
.and(means2_full)
.map_collect(|&mean1, &mean2| ((mean1 + mean2) / 2.0)))
} else { None };

// log time for debugging
// let time_after_loading = js_sys::Date::now();

// calculate similarities
let a: ArrayView1<f32> = repr1.slice(s![step1,row*n+col,..]);
let a: ArrayView1<f32> = repr1.slice(s![row*n+col,..]);
let mut similarities: Vec<f32> = match func.as_str() {
"cosine" => repr2.slice(s![step2, .., ..])
"cosine" => repr2
.axis_iter(Axis(0))
.enumerate()
.map(|(index, b)| b.dot(&a) / (norms1[[step1, index]] * norms2[[step2, index]]))
.map(|(index, b)| b.dot(&a) / (norms1[[index]] * norms2[[index]]))
.collect::<Vec<_>>(),
"cosine_centered" => repr2.slice(s![step2,..,..])
"cosine_centered" => repr2
.axis_iter(Axis(0))
.enumerate()
.map(|(index, b)|
Zip::from(a)
.and(b)
.and(means.as_ref().unwrap().view())
.fold(0.0, |acc, &ai, &bi, &mean| acc + (ai - mean) * (bi - mean)) / (norms1[[step1, index]] * norms2[[step2, index]]))
.fold(0.0, |acc, &ai, &bi, &mean| acc + (ai - mean) * (bi - mean)) / (norms1[[index]] * norms2[[index]]))
.collect::<Vec<_>>(),
"manhattan" => repr2.slice(s![step2,..,..])
"manhattan" => repr2
.axis_iter(Axis(0))
.map(|b| Zip::from(a).and(b).fold(0.0, |acc, &ai, &bi| acc + (ai - bi).abs()))
.collect::<Vec<_>>(),
"euclidean" => repr2.slice(s![step2,..,..])
"euclidean" => repr2
.axis_iter(Axis(0))
.map(|b| Zip::from(a).and(b).fold(0.0, |acc, &ai, &bi| acc + (ai - bi).powi(2)).sqrt())
.collect::<Vec<_>>(),
"chebyshev" => repr2.slice(s![step2,..,..])
"chebyshev" => repr2
.axis_iter(Axis(0))
.map(|b| Zip::from(a).and(b).fold(0.0, |acc: f32, &ai, &bi| acc.max((ai - bi).abs())))
.collect::<Vec<_>>(),
Expand Down Expand Up @@ -114,7 +112,7 @@ pub fn calc_similarities(

// fetch representation from url and store it in cache
#[wasm_bindgen]
pub async fn fetch_repr(url: String, steps: usize, n: usize, m: usize) -> Result<(), JsValue> {
pub async fn fetch_repr(url: String, n: usize, m: usize) -> Result<(), JsValue> {
console_error_panic_hook::set_once(); // better error messages in the console

// if the representation is already fetched, return
Expand Down Expand Up @@ -145,21 +143,16 @@ pub async fn fetch_repr(url: String, steps: usize, n: usize, m: usize) -> Result
let float16_data: Vec<f16> = bytes.chunks_exact(2).map(|chunk| f16::from_le_bytes([chunk[0], chunk[1]])).collect();

// convert float16 vector to an f32 array with the expected shape
let representations = match Array3::from_shape_vec((steps, n*n, m), float16_data.iter().map(|&x| f32::from(x)).collect()) {
let representations = match Array2::from_shape_vec((n*n, m), float16_data.iter().map(|&x| f32::from(x)).collect()) {
Ok(repr) => repr,
Err(e) => {
let new_steps = float16_data.len() / (n*n*m);
if new_steps * n*n*m != float16_data.len() {
return Err(JsValue::from_str(format!("Failed to convert float16 vector (len {}, {}) to Array3<f32> with shape ({}, {}, {}): {:#?}", float16_data.len(), url, steps, n*n, m, e).as_str()))
}
console::warn_1(&JsValue::from_str(format!("Failed to convert float16 vector (len {}, {}) to Array3<f32> with shape ({}, {}, {}), using shape ({}, {}, {}) instead", float16_data.len(), url, steps, n*n, m, new_steps, n*n, m).as_str()));
Array3::from_shape_vec((new_steps, n*n, m), float16_data.iter().map(|&x| f32::from(x)).collect()).map_err(jserr!())?
return Err(JsValue::from_str(format!("Failed to convert float16 vector (len {}, {}) to Array3<f32> with shape ({}, {}): {:#?}", float16_data.len(), url, n*n, m, e).as_str()))
}
};

// store representations and means and norms in cache
let means = representations.mean_axis(Axis(1)).ok_or(jsnone!())?;
let norms = representations.mapv(|x| x.powi(2)).sum_axis(Axis(2)).mapv(f32::sqrt);
let means = representations.mean_axis(Axis(0)).ok_or(jsnone!())?;
let norms = representations.mapv(|x| x.powi(2)).sum_axis(Axis(1)).mapv(f32::sqrt);
GLOBAL_MAP.with(|map| {
map.borrow_mut().insert(url.to_string(), Arc::new((representations, means, norms)));
});
Expand Down
8 changes: 4 additions & 4 deletions worker/webworker.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ onmessage = function(e) {
asdf.then(_ => {
const id = e.data.id;
if (e.data.task === 'fetch_repr') {
const { url, steps, n, m } = e.data.data;
wasm.fetch_repr(url, steps, n, m)
const { url, n, m } = e.data.data;
wasm.fetch_repr(url, n, m)
.then(() => postMessage({id, data: {status: 'success'}}))
.catch((e) => postMessage({id, data: {status: 'error', msg: e}}))
} else if (e.data.task === 'calc_similarities') {
const { func, repr1_str, repr2_str, step1, step2, row, col } = e.data.data;
const { func, repr1_str, repr2_str, row, col } = e.data.data;
try {
const similarities = wasm.calc_similarities(func, repr1_str, repr2_str, step1, step2, row, col);
const similarities = wasm.calc_similarities(func, repr1_str, repr2_str, row, col);
postMessage({ id, data: similarities });
} catch (e) {
if (e === 'loading') {
Expand Down

0 comments on commit 1c06a72

Please sign in to comment.