Skip to content

Commit

Permalink
Add training triplets to the db
Browse files Browse the repository at this point in the history
  • Loading branch information
Polochon-street committed Sep 8, 2024
1 parent 2efd280 commit 1318a3f
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 5 deletions.
29 changes: 24 additions & 5 deletions src/library.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ pub trait AppConfigTrait: Serialize + Sized + DeserializeOwned {
}
}

#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
/// The minimum configuration an application needs to work with
/// a [Library].
pub struct BaseConfig {
Expand All @@ -239,6 +239,7 @@ pub struct BaseConfig {
/// The number of CPU cores an analysis will be performed with.
/// Defaults to the number of CPUs in the user's computer.
number_cores: NonZeroUsize,
m: Vec<Vec<f32>>,
}

impl BaseConfig {
Expand Down Expand Up @@ -293,6 +294,7 @@ impl BaseConfig {
database_path,
features_version: FEATURES_VERSION,
number_cores,
m: vec![vec![3.]],
})
}
}
Expand Down Expand Up @@ -400,6 +402,23 @@ impl<Config: AppConfigTrait, D: ?Sized + DecoderTrait> Library<Config, D> {
alter table song rename column track_number_1 to track_number;
",
"alter table song add column disc_number integer;",
"
-- Training triplets used to do metric learning, in conjunction with
-- a human-processed survey. In this table, songs pointed to
-- by song_1_id and song_2_id are closer together than they
-- are to the song pointed to by odd_one_out_id, i.e.
-- d(s1, s2) < d(s1, odd_one_out) and d(s1, s2) < d(s2, odd_one_out)
create table training_triplet (
id integer primary key,
song_1_id integer not null,
song_2_id integer not null,
odd_one_out_id integer not null,
stamp timestamp default current_timestamp,
foreign key(song_1_id) references song(id) on delete cascade,
foreign key(song_2_id) references song(id) on delete cascade,
foreign key(odd_one_out_id) references song(id) on delete cascade
)
",
];

/// Create a new [Library] object from the given Config struct that
Expand Down Expand Up @@ -1444,7 +1463,7 @@ mod test {
metadata_bliss_does_not_have: String,
}

#[derive(Deserialize, Serialize, PartialEq, Eq, Debug, Clone)]
#[derive(Deserialize, Serialize, PartialEq, Debug, Clone)]
struct CustomConfig {
#[serde(flatten)]
base_config: BaseConfig,
Expand Down Expand Up @@ -3355,7 +3374,7 @@ mod test {
let version: u32 = sqlite_conn
.query_row("pragma user_version", [], |row| row.get(0))
.unwrap();
assert_eq!(version, 3);
assert_eq!(version, 4);
// Make sure we can call this over and over without any problem
Library::<BaseConfig, DummyDecoder>::new_from_base(
Some(config_dir.path().join("config.txt")),
Expand All @@ -3366,7 +3385,7 @@ mod test {
let version: u32 = sqlite_conn
.query_row("pragma user_version", [], |row| row.get(0))
.unwrap();
assert_eq!(version, 3);
assert_eq!(version, 4);
}

#[test]
Expand All @@ -3390,7 +3409,7 @@ mod test {
let version: u32 = sqlite_conn
.query_row("pragma user_version", [], |row| row.get(0))
.unwrap();
assert_eq!(version, 3);
assert_eq!(version, 4);
}

#[test]
Expand Down
63 changes: 63 additions & 0 deletions src/playlist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,69 @@ use ndarray_stats::QuantileExt;
use noisy_float::prelude::*;
use std::collections::HashMap;

const M_EUCLIDEAN: &[&[f32]] = &[
&[
1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
],
&[
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
],
];

/// Trait for creating a distance metric, measuring the distance to a set of vectors. If this
/// metric requires any kind of training, this should be done in the build function so that the
/// returned DistanceMetric instance is already trained and ready to use.
Expand Down

0 comments on commit 1318a3f

Please sign in to comment.