diff --git a/optimum/exporters/neuron/model_configs/decoder_configs.py b/optimum/exporters/neuron/model_configs/decoder_configs.py index 1dd30fb0b..7013090bc 100644 --- a/optimum/exporters/neuron/model_configs/decoder_configs.py +++ b/optimum/exporters/neuron/model_configs/decoder_configs.py @@ -12,12 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Neuron export configurations for decoder models.""" - -import importlib - -from transformers_neuronx import ContinuousBatchingConfig -from transformers_neuronx import NeuronConfig as TnxNeuronConfig +"""Neuron export configurations for models using transformers_neuronx.""" from optimum.exporters.tasks import TasksManager @@ -168,6 +163,7 @@ class GraniteNeuronConfig(NeuronDecoderExportConfig): NEURONX_CLASS = GraniteForSampling CONTINUOUS_BATCHING = True + @register_in_tasks_manager("phi4", "text-generation") class Phi4NeuronConfig(TextNeuronDecoderConfig): NEURONX_CLASS = Phi4ForSampling