diff --git a/extra/prompt_tuning/tuning/loaders.py b/extra/prompt_tuning/tuning/loaders.py index 2cc7dc96..9553643c 100644 --- a/extra/prompt_tuning/tuning/loaders.py +++ b/extra/prompt_tuning/tuning/loaders.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod -from typing import Dict, Iterable, List +from typing import Iterable, List import dspy.datasets from dspy import Example +from omegaconf import DictConfig class DataLoader(ABC): @@ -10,7 +11,7 @@ class DataLoader(ABC): Data loader. """ - def __init__(self, config: Dict) -> None: + def __init__(self, config: DictConfig) -> None: self.config = config @abstractmethod