From 8a3f2d86b2026097c1cdb960b71c3bd150a6b903 Mon Sep 17 00:00:00 2001 From: Pete Date: Fri, 2 Feb 2024 10:36:24 -0800 Subject: [PATCH] Fix HF integration for Python < 3.10 (#426) --- .github/workflows/main.yml | 2 +- .github/workflows/pr_checks.yml | 1 + CHANGELOG.md | 4 ++++ hf_olmo/modeling_olmo.py | 5 +++-- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4e68986c5..adac82794 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -104,7 +104,7 @@ jobs: gpu_tests: name: GPU Tests runs-on: ubuntu-latest - timeout-minutes: 15 + timeout-minutes: 8 env: BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} BEAKER_IMAGE: olmo-torch2-test diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index 12f918aaf..6accc241a 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -10,6 +10,7 @@ on: - main paths: - 'olmo/**' + - 'hf_olmo/**' jobs: changelog: diff --git a/CHANGELOG.md b/CHANGELOG.md index c9cec4836..d5a33a19f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Fixed + +- Fixed an issue with the HuggingFace integration where we were inadvertently using a feature that was introduced in Python 3.10, causing an error for older Python versions. + ## [v0.2.3](https://github.com/allenai/OLMo/releases/tag/v0.2.3) - 2024-01-31 ## [v0.2.2](https://github.com/allenai/LLM/releases/tag/v0.2.2) - 2023-12-10 diff --git a/hf_olmo/modeling_olmo.py b/hf_olmo/modeling_olmo.py index cf4496305..8856be8ad 100644 --- a/hf_olmo/modeling_olmo.py +++ b/hf_olmo/modeling_olmo.py @@ -1,3 +1,4 @@ +from dataclasses import fields from typing import List, Optional, Tuple, Union import torch @@ -17,8 +18,8 @@ def create_model_config_from_pretrained_config(config: OLMoConfig): """ kwargs = {} - for key in ModelConfig.__match_args__: - kwargs[key] = getattr(config, key) + for field in fields(ModelConfig): + kwargs[field.name] = getattr(config, field.name) model_config = ModelConfig(**kwargs) return model_config