Skip to content

fix conversion for text embeddings for fp16 models (#968) #3954

fix conversion for text embeddings for fp16 models (#968)

fix conversion for text embeddings for fp16 models (#968) #3954

Workflow file for this run

# Workflow that runs the IPEX test suite.
name: IPEX - Test

# Triggers: pushes to main and to release branches, plus pull requests
# that target main.
on:
  push:
    branches: [main, "v*-release"]
  pull_request:
    branches: [main]

# Runs are grouped by workflow name and head ref; when head_ref is empty the
# expression falls back to the run id. A newer run in the same group cancels
# the one still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true
jobs:
  # Installs the package with its IPEX extras and runs tests/ipex once per
  # torch-version x transformers-version combination.
  build:
    strategy:
      # Let every matrix combination run to completion even if one fails.
      fail-fast: false
      matrix:
        # Version specifiers are quoted so YAML keeps them as strings.
        torch-version: ["2.2.0", "2.3.*", "2.4.*"]
        transformers-version: ["4.39.0", "4.44.*"]
    runs-on: ubuntu-22.04
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          # Quoted: an unquoted 3.9 is a YAML float, and a later bump to
          # 3.10 would silently be read as 3.1.
          python-version: "3.9"

      - name: Install dependencies
        # Requirement specifiers are quoted so the shell never treats the
        # `[...]` extras or the `*` wildcards as glob patterns.
        # intel_extension_for_pytorch is pinned with the same specifier as
        # torch.
        run: |
          pip install --upgrade pip
          pip install "torch==${{ matrix.torch-version }}" torchaudio torchvision \
            --extra-index-url https://download.pytorch.org/whl/cpu
          pip install ".[ipex,tests]" \
            "transformers[testing]==${{ matrix.transformers-version }}" \
            "intel_extension_for_pytorch==${{ matrix.torch-version }}"

      # Only the torch 2.2.0 matrix entries pin numpy to the 1.x series.
      - name: Downgrade Numpy
        if: ${{ matrix.torch-version == '2.2.0' }}
        run: pip install "numpy==1.*"

      - name: Assert versions
        # Fail if the installed versions do not start with the requested
        # matrix versions; a trailing ".*" is stripped before comparing.
        run: |
          python -c "import torch; print(torch.__version__); assert torch.__version__.startswith('${{ matrix.torch-version }}'.replace('.*', ''))"
          python -c "import intel_extension_for_pytorch; print(intel_extension_for_pytorch.__version__); assert intel_extension_for_pytorch.__version__.startswith('${{ matrix.torch-version }}'.replace('.*', ''))"
          python -c "import transformers; print(transformers.__version__); assert transformers.__version__.startswith('${{ matrix.transformers-version }}'.replace('.*', ''))"

      - name: Test with Pytest
        run: |
          pytest tests/ipex