Merge pull request #68 from bio-ontology-research-group/box2el
EL Geometric Models
Showing 24 changed files with 476 additions and 238 deletions.
Empty file.
@@ -0,0 +1,130 @@
from mowl.evaluation.base import AxiomsRankBasedEvaluator
from mowl.projection.factory import projector_factory
from mowl.projection.edge import Edge
import logging
import numpy as np
from scipy.stats import rankdata
import torch as th


class BoxSquaredELPPIEvaluator(AxiomsRankBasedEvaluator):
    """Rank-based evaluator for protein-protein interaction axioms."""

    def __init__(
            self,
            axioms,
            eval_method,
            axioms_to_filter,
            class_name_indexemb,
            rel_name_indexemb,
            device="cpu",
            verbose=False
    ):
        super().__init__(axioms, eval_method, axioms_to_filter, device, verbose)

        self.class_name_indexemb = class_name_indexemb
        self.relation_name_indexemb = rel_name_indexemb

        self._loaded_training_scores = False
        self._loaded_eval_data = False
        self._loaded_ht_data = False

    def _load_head_tail_entities(self):
        if self._loaded_ht_data:
            return

        ents, _ = Edge.getEntitiesAndRelations(self.axioms)
        ents_filter, _ = Edge.getEntitiesAndRelations(self.axioms_to_filter)

        entities = list(set(ents) | set(ents_filter))

        # Keep only entities that have an embedding; head and tail entities
        # are both filtered against the same class dictionary.
        self.head_entities = set()
        for e in entities:
            if e in self.class_name_indexemb:
                self.head_entities.add(e)
            else:
                logging.info("Entity %s not present in the embeddings dictionary. Ignoring it.", e)

        self.tail_entities = set()
        for e in entities:
            if e in self.class_name_indexemb:
                self.tail_entities.add(e)
            else:
                logging.info("Entity %s not present in the embeddings dictionary. Ignoring it.", e)

        self.head_name_indexemb = {k: self.class_name_indexemb[k] for k in self.head_entities}
        self.tail_name_indexemb = {k: self.class_name_indexemb[k] for k in self.tail_entities}

        # Map embedding indices to row/column positions in the scores matrix.
        self.head_indexemb_indexsc = {v: k for k, v in enumerate(self.head_name_indexemb.values())}
        self.tail_indexemb_indexsc = {v: k for k, v in enumerate(self.tail_name_indexemb.values())}

        self._loaded_ht_data = True

    def _load_training_scores(self):
        if self._loaded_training_scores:
            return self.training_scores

        self._load_head_tail_entities()

        training_scores = np.ones((len(self.head_entities), len(self.tail_entities)),
                                  dtype=np.int32)

        if self._compute_filtered_metrics:
            # Careful here: c must be in head entities and d must be in tail entities.
            for axiom in self.axioms_to_filter:
                c, _, d = axiom.astuple()
                if (c not in self.head_entities) or (d not in self.tail_entities):
                    continue

                c, d = self.head_name_indexemb[c], self.tail_name_indexemb[d]
                c, d = self.head_indexemb_indexsc[c], self.tail_indexemb_indexsc[d]

                # Inflate the scores of known training pairs so they sink to
                # the bottom of the filtered ranking (lower score is better).
                training_scores[c, d] = 10000

            logging.info("Training scores created")

        self._loaded_training_scores = True
        return training_scores

    def _init_axioms(self, axioms):
        if axioms is None:
            return None

        projector = projector_factory("taxonomy_rels", relations=["http://interacts_with"])

        edges = projector.project(axioms)
        return edges  # list of Edges

    def compute_axiom_rank(self, axiom):
        self.training_scores = self._load_training_scores()

        c, r, d = axiom.astuple()

        if (c not in self.head_entities) or (d not in self.tail_entities):
            return None, None, None

        # Embedding indices
        c_emb_idx, d_emb_idx = self.head_name_indexemb[c], self.tail_name_indexemb[d]

        # Scores matrix indices
        c_sc_idx = self.head_indexemb_indexsc[c_emb_idx]
        d_sc_idx = self.tail_indexemb_indexsc[d_emb_idx]

        r = self.relation_name_indexemb[r]

        # Score (c, r, x) for every candidate tail entity x.
        data = th.tensor([[c_emb_idx, r, self.tail_name_indexemb[x]]
                          for x in self.tail_entities]).to(self.device)

        res = self.eval_method(data).squeeze().cpu().detach().numpy()

        index = rankdata(res, method='average')
        rank = index[d_sc_idx]

        # Filtered rank: multiplying by the training-scores row pushes known
        # training pairs out of the way before re-ranking.
        findex = rankdata((res * self.training_scores[c_sc_idx, :]), method='average')
        frank = findex[d_sc_idx]

        return rank, frank, len(self.tail_entities)
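As a hedged usage sketch (not part of this commit): the evaluator is wired to a trained model's scoring function, and compute_axiom_rank returns the raw rank, the filtered rank (with known training pairs pushed down via the 10000 multiplier), and the number of candidates. The names model, test_axioms, train_axioms, and some_test_axiom below are illustrative assumptions.

    # Hypothetical wiring, assuming a trained model exposing eval_method and
    # the index dictionaries used above; variable names are illustrative only.
    evaluator = BoxSquaredELPPIEvaluator(
        test_axioms,                        # axioms to rank
        model.eval_method,                  # scores (head, relation, tail) index triples
        train_axioms,                       # known axioms excluded from filtered metrics
        model.class_index_dict,             # class name -> embedding index
        model.object_property_index_dict,   # relation name -> embedding index
        device="cpu")
    rank, frank, n_candidates = evaluator.compute_axiom_rank(some_test_axiom)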
@@ -0,0 +1,77 @@
from mowl.nn import BoxSquaredELModule
from mowl.base_models.elmodel import EmbeddingELModel
from mowl.models.boxsquaredel.evaluate import BoxSquaredELPPIEvaluator
import torch as th
from torch import nn


class BoxSquaredEL(EmbeddingELModel):
    """
    Implementation based on [peng2020]_.
    """

    def __init__(self,
                 dataset,
                 embed_dim=50,
                 margin=0.02,
                 reg_norm=1,
                 learning_rate=0.001,
                 epochs=1000,
                 batch_size=4096 * 8,
                 delta=2.5,
                 reg_factor=0.2,
                 num_negs=4,
                 model_filepath=None,
                 device='cpu'
                 ):
        super().__init__(dataset, embed_dim, batch_size, extended=True,
                         model_filepath=model_filepath)

        self.margin = margin
        self.reg_norm = reg_norm
        self.delta = delta
        self.reg_factor = reg_factor
        self.num_negs = num_negs
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.device = device
        self._loaded = False
        self.extended = False
        self.init_module()

    def init_module(self):
        self.module = BoxSquaredELModule(
            len(self.class_index_dict),
            len(self.object_property_index_dict),
            embed_dim=self.embed_dim,
            gamma=self.margin,
            delta=self.delta,
            reg_factor=self.reg_factor
        ).to(self.device)

    def train(self):
        raise NotImplementedError

    def eval_method(self, data):
        # GCI2 (C subclassOf R some D) scores drive the PPI evaluation.
        return self.module.gci2_score(data)

    def get_embeddings(self):
        self.init_module()

        print('Loading the best model from', self.model_filepath)
        self.load_best_model()

        ent_embeds = {k: v for k, v in zip(self.class_index_dict.keys(),
                                           self.module.class_embed.weight.cpu().detach().numpy())}
        rel_embeds = {k: v for k, v in zip(self.object_property_index_dict.keys(),
                                           self.module.rel_embed.weight.cpu().detach().numpy())}
        return ent_embeds, rel_embeds

    def load_best_model(self):
        self.init_module()
        self.module.load_state_dict(th.load(self.model_filepath))
        self.module.eval()
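A minimal end-to-end sketch of how this class might be used, under stated assumptions: mowl.init_jvm and PPIYeastSlimDataset are mOWL APIs not part of this commit, and since train() is unimplemented here, a saved checkpoint at model_filepath is assumed before get_embeddings() is called.

    # Hypothetical usage; dataset choice and checkpoint path are assumptions.
    import mowl
    mowl.init_jvm("8g")  # mOWL needs a running JVM before loading datasets
    from mowl.datasets.builtin import PPIYeastSlimDataset

    dataset = PPIYeastSlimDataset()
    model = BoxSquaredEL(dataset, embed_dim=50, model_filepath="box2el.th")
    # train() raises NotImplementedError in this commit, so a training loop
    # must first save a state dict to model_filepath; only then:
    ent_embeds, rel_embeds = model.get_embeddings()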