Skip to content

Commit

Permalink
cosine similarity bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasLoos committed Sep 19, 2024
1 parent bb941ca commit cd69f0b
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 13 deletions.
11 changes: 4 additions & 7 deletions worker/pkg/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ function handleError(f, args) {
wasm.__wbindgen_exn_store(addHeapObject(e));
}
}
function __wbg_adapter_55(arg0, arg1, arg2, arg3) {
function __wbg_adapter_53(arg0, arg1, arg2, arg3) {
wasm.wasm_bindgen__convert__closures__invoke2_mut__h493f10fe887c5e99(arg0, arg1, addHeapObject(arg2), addHeapObject(arg3));
}

Expand Down Expand Up @@ -375,9 +375,6 @@ function __wbg_get_imports() {
const ret = getObject(arg0).arrayBuffer();
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbg_log_79d3c56888567995 = function(arg0) {
console.log(getObject(arg0));
};
imports.wbg.__wbg_newwithstrandinit_11fbc38beb4c26b0 = function() { return handleError(function (arg0, arg1, arg2) {
const ret = new Request(getStringFromWasm0(arg0, arg1), getObject(arg2));
return addHeapObject(ret);
Expand Down Expand Up @@ -450,7 +447,7 @@ function __wbg_get_imports() {
const a = state0.a;
state0.a = 0;
try {
return __wbg_adapter_55(a, state0.b, arg0, arg1);
return __wbg_adapter_53(a, state0.b, arg0, arg1);
} finally {
state0.a = a;
}
Expand Down Expand Up @@ -506,8 +503,8 @@ function __wbg_get_imports() {
const ret = wasm.memory;
return addHeapObject(ret);
};
imports.wbg.__wbindgen_closure_wrapper266 = function(arg0, arg1, arg2) {
const ret = makeMutClosure(arg0, arg1, 45, __wbg_adapter_20);
imports.wbg.__wbindgen_closure_wrapper236 = function(arg0, arg1, arg2) {
const ret = makeMutClosure(arg0, arg1, 44, __wbg_adapter_20);
return addHeapObject(ret);
};

Expand Down
Binary file modified worker/pkg/worker_bg.wasm
Binary file not shown.
8 changes: 4 additions & 4 deletions worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use std::collections::HashMap;
use half::f16;
use web_sys::{js_sys, Request, RequestInit, RequestMode, Response, console};
use web_sys::{js_sys, Request, RequestInit, RequestMode, Response};
// use web_sys::console
use std::cell::RefCell;
use js_sys::{ArrayBuffer, Uint8Array};
use console_error_panic_hook;
Expand Down Expand Up @@ -62,7 +63,7 @@ pub fn calc_similarities(
"cosine" => repr2
.axis_iter(Axis(0))
.enumerate()
.map(|(index, b)| b.dot(&a) / (norms1[[index]] * norms2[[index]]))
.map(|(index, b)| b.dot(&a) / (norms1[[row*n+col]] * norms2[[index]]))
.collect::<Vec<_>>(),
"cosine_centered" => repr2
.axis_iter(Axis(0))
Expand All @@ -71,7 +72,7 @@ pub fn calc_similarities(
Zip::from(a)
.and(b)
.and(means.as_ref().unwrap().view())
.fold(0.0, |acc, &ai, &bi, &mean| acc + (ai - mean) * (bi - mean)) / (norms1[[index]] * norms2[[index]]))
.fold(0.0, |acc, &ai, &bi, &mean| acc + (ai - mean) * (bi - mean)) / (norms1[[row*n+col]] * norms2[[index]]))
.collect::<Vec<_>>(),
"manhattan" => repr2
.axis_iter(Axis(0))
Expand Down Expand Up @@ -117,7 +118,6 @@ pub async fn fetch_repr(url: String, n: usize, m: usize) -> Result<(), JsValue>

// if the representation is already fetched, return
if GLOBAL_MAP.with(|map| map.borrow().contains_key(&url)) {
console::log_1(&JsValue::from_str(&format!("WASM: Representation already fetched: {}", url)));
return Ok(());
}

Expand Down
4 changes: 2 additions & 2 deletions worker/webworker.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import * as wasm from './pkg/worker.js';

const asdf = wasm.default().then(_ => null);
const wasmInit = wasm.default().then(_ => null);

onmessage = function(e) {
asdf.then(_ => {
wasmInit.then(_ => {
const id = e.data.id;
if (e.data.task === 'fetch_repr') {
const { url, n, m } = e.data.data;
Expand Down

0 comments on commit cd69f0b

Please sign in to comment.