Skip to content

Commit

Permalink
Crash fix for TalkingHead evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
radekd91 committed Dec 3, 2023
1 parent 1eb523c commit a2f5a0a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
10 changes: 10 additions & 0 deletions inferno_apps/EmotionRecognition/training/train_emodeca.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from inferno.models.external.EmoDeep3DFace import EmoDeep3DFace
except ImportError as e:
print("Could not import EmoDeep3DFace")
except OSError as e:
print("Could not import EmoDeep3DFace")
# warning: external import collision
# try:
# from inferno.models.external.Emo3DDFA_V2 import Emo3DDFA_v2
Expand Down Expand Up @@ -54,6 +56,12 @@
from inferno_apps.EMOCA.utils.load import hack_paths


try:
from inferno.sandbox.infernal.models.EmotionRecognition.EmoFocus import EmoFocus
except ImportError as e:
print("Could not import EmoShapeCycle")


# project_name = 'EmoDECA'


Expand Down Expand Up @@ -189,6 +197,8 @@ def single_stage_deca_pass(deca, cfg, stage, prefix, dm=None, logger=None,
from inferno.models.external.EmoDeep3DFace import EmoDeep3DFace
except ImportError as e:
print("Could not import EmoDeep3DFace")
except OSError as e:
print("Could not import EmoDeep3DFace")
if cfg.model.emodeca_type == 'Emo3DDFA_v2':
## ugly and yucky import but otherwise there's import collisions with EmoDeep3DFace
try:
Expand Down
14 changes: 8 additions & 6 deletions inferno_apps/TalkingHead/evaluation/evaluation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import soundfile as sf
from psbody.mesh import Mesh
from inferno.utils.other import get_path_to_assets
import omegaconf


def create_condition(talking_head, sample, emotions=None, intensities=None, identities=None):
Expand Down Expand Up @@ -150,13 +151,14 @@ def create_base_sample(talking_head, audio_path, smallest_unit=1, silent_frames_
if silence_all:
sample["raw_audio"] = np.zeros_like(sample["raw_audio"])
T = sample["raw_audio"].shape[0]
reconstruction_type = talking_head.cfg.data.reconstruction_type[0] if isinstance(talking_head.cfg.data.reconstruction_type, (list, omegaconf.ListConfig)) else talking_head.cfg.data.reconstruction_type
sample["reconstruction"] = {}
sample["reconstruction"][talking_head.cfg.data.reconstruction_type[0]] = {}
sample["reconstruction"][talking_head.cfg.data.reconstruction_type[0]]["gt_exp"] = np.zeros((T, 50), dtype=np.float32)
sample["reconstruction"][talking_head.cfg.data.reconstruction_type[0]]["gt_shape"] = np.zeros((300), dtype=np.float32)
# sample["reconstruction"][talking_head.cfg.data.reconstruction_type[0]]["gt_shape"] = np.zeros((T, 300), dtype=np.float32)
sample["reconstruction"][talking_head.cfg.data.reconstruction_type[0]]["gt_jaw"] = np.zeros((T, 3), dtype=np.float32)
sample["reconstruction"][talking_head.cfg.data.reconstruction_type[0]]["gt_tex"] = np.zeros((50), dtype=np.float32)
sample["reconstruction"][reconstruction_type] = {}
sample["reconstruction"][reconstruction_type]["gt_exp"] = np.zeros((T, 50), dtype=np.float32)
sample["reconstruction"][reconstruction_type]["gt_shape"] = np.zeros((300), dtype=np.float32)
# sample["reconstruction"][reconstruction_type]["gt_shape"] = np.zeros((T, 300), dtype=np.float32)
sample["reconstruction"][reconstruction_type]["gt_jaw"] = np.zeros((T, 3), dtype=np.float32)
sample["reconstruction"][reconstruction_type]["gt_tex"] = np.zeros((50), dtype=np.float32)
sample = create_condition(talking_head, sample)
return sample

Expand Down

0 comments on commit a2f5a0a

Please sign in to comment.