Fix the link error when building compress_weights with Clang on macOS
ufownl committed Feb 8, 2025
1 parent b18bd78 commit 3a5a6db
Showing 2 changed files with 46 additions and 37 deletions.
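The commit title names the symptom; the diff points at the likely cause. Before this change, the generic LayerWeightsPtrs<Weight>::Reshape was a class-template member whose body lived only in gemma/weights.cc, so a separate translation unit such as the compress_weights tool, instantiating it for a weight type not explicitly instantiated there, is left with an undefined symbol at link time. Moving the generic body into gemma/weights.h (and keeping only the NuqStream specialization in the .cc) lets every translation unit instantiate what it needs. A minimal sketch of that failure mode, using hypothetical names (Widget, RunTool) rather than the project's own:

// widget.h (sketch)
template <class T>
struct Widget {
  void Reshape();  // declared only; the body lives in widget.cc
};

// widget.cc (sketch): the compiler emits code only for instantiations made in
// this translation unit, and none are made here.
template <class T>
void Widget<T>::Reshape() { /* ... */ }

// tool.cc (sketch; think compress_weights): this TU instantiates
// Widget<float> but never sees the body of Reshape, so no object file defines
// Widget<float>::Reshape() and the final link reports an undefined symbol.
void RunTool() {
  Widget<float> w;
  w.Reshape();
}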
58 changes: 22 additions & 36 deletions gemma/weights.cc
@@ -257,8 +257,8 @@ void ModelWeightsStorage::CreateForType(Type weight_type,
   }
 }
 
-template <class Weight>
-void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
+template <>
+void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
   if (attn_vec_einsum_w.data() == nullptr) return;
 
   const size_t model_dim = layer_config.model_dim;
@@ -271,48 +271,34 @@ void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
     att_weights.SetPtr(*storage);
   }
 
-  if (hwy::IsSame<Weight, NuqStream>()) {
-    const hwy::HWY_NAMESPACE::ScalableTag<float> df;
+  const hwy::HWY_NAMESPACE::ScalableTag<float> df;
 
-    hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
-        hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
-    hwy::AlignedFreeUniquePtr<float[]> att_weights_tmp =
-        hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
+  hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
+      hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
+  hwy::AlignedFreeUniquePtr<float[]> att_weights_tmp =
+      hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
 
-    HWY_NAMESPACE::DecompressAndZeroPad(
-        df, MakeSpan(attn_vec_einsum_w.data(), model_dim * heads * qkv_dim), 0,
-        attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim);
-
-    for (size_t m = 0; m < model_dim; ++m) {
-      float* HWY_RESTRICT out_row = att_weights_tmp.get() + m * heads * qkv_dim;
-      for (size_t h = 0; h < heads; ++h) {
-        hwy::CopyBytes(
-            attn_vec_einsum_w_tmp.get() + h * model_dim * qkv_dim + m * qkv_dim,
-            out_row + h * qkv_dim, qkv_dim * sizeof(float));
-      }
-    }
-
-    CompressWorkingSet work;
-    hwy::ThreadPool pool(0);
-
-    HWY_NAMESPACE::Compress(
-        att_weights_tmp.get(), model_dim * heads * qkv_dim, work,
-        MakeSpan(att_weights.data(), model_dim * heads * qkv_dim),
-        /*packed_ofs=*/0, pool);
-
-    att_weights.set_scale(attn_vec_einsum_w.scale());
-
-    return;
-  }
+  HWY_NAMESPACE::DecompressAndZeroPad(
+      df, MakeSpan(attn_vec_einsum_w.data(), model_dim * heads * qkv_dim), 0,
+      attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim);
 
   for (size_t m = 0; m < model_dim; ++m) {
-    Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
+    float* HWY_RESTRICT out_row = att_weights_tmp.get() + m * heads * qkv_dim;
     for (size_t h = 0; h < heads; ++h) {
       hwy::CopyBytes(
-          attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
-          out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
+          attn_vec_einsum_w_tmp.get() + h * model_dim * qkv_dim + m * qkv_dim,
+          out_row + h * qkv_dim, qkv_dim * sizeof(float));
     }
   }
 
+  CompressWorkingSet work;
+  hwy::ThreadPool pool(0);
+
+  HWY_NAMESPACE::Compress(
+      att_weights_tmp.get(), model_dim * heads * qkv_dim, work,
+      MakeSpan(att_weights.data(), model_dim * heads * qkv_dim),
+      /*packed_ofs=*/0, pool);
+
   att_weights.set_scale(attn_vec_einsum_w.scale());
 }
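Both the removed and the surviving code perform the same layout change described by the comment in weights.h below: attn_vec_einsum_w, stored head-major as [heads, model_dim, qkv_dim], becomes att_weights stored as [model_dim, heads * qkv_dim]. A standalone sketch of just that index mapping on plain float buffers (illustrative only; the NuqStream path additionally decompresses before and re-compresses after this copy):

#include <cstddef>
#include <cstring>

// Illustrative sketch: copy src laid out as [heads][model_dim][qkv_dim] into
// dst laid out as [model_dim][heads * qkv_dim], mirroring the loop above.
void ReshapeAttWeights(const float* src, float* dst, size_t model_dim,
                       size_t heads, size_t qkv_dim) {
  for (size_t m = 0; m < model_dim; ++m) {
    float* out_row = dst + m * heads * qkv_dim;
    for (size_t h = 0; h < heads; ++h) {
      const float* in_row = src + h * model_dim * qkv_dim + m * qkv_dim;
      std::memcpy(out_row + h * qkv_dim, in_row, qkv_dim * sizeof(float));
    }
  }
}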

25 changes: 24 additions & 1 deletion gemma/weights.h
@@ -179,7 +179,30 @@ struct LayerWeightsPtrs {
   // Initializes att_weights from attn_vec_einsum_w, hence this must be called
   // after loading weights via ForEachTensor.
   // TODO: update compression/convert_weights to bake this in.
-  void Reshape(MatStorage* storage);
+  void Reshape(MatStorage* storage) {
+    static_assert(!hwy::IsSame<Weight, NuqStream>());
+
+    if (attn_vec_einsum_w.data() == nullptr) return;
+
+    const size_t model_dim = layer_config.model_dim;
+    const size_t heads = layer_config.heads;
+    const size_t qkv_dim = layer_config.qkv_dim;
+
+    // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
+    if (storage != nullptr) {
+      storage->Allocate();
+      att_weights.SetPtr(*storage);
+    }
+    for (size_t m = 0; m < model_dim; ++m) {
+      Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
+      for (size_t h = 0; h < heads; ++h) {
+        hwy::CopyBytes(
+            attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
+            out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
+      }
+    }
+    att_weights.set_scale(attn_vec_einsum_w.scale());
+  }
 
   // Used by ForEachTensor for per-layer tensors.
 #define GEMMA_CALL_FUNC(member) \
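Taken together, the arrangement after this commit is: the generic Reshape is defined inline in the header so any translation unit can instantiate it for its weight type; the static_assert keeps that byte-copying body away from NuqStream, whose packed representation has no fixed-size elements for CopyBytes to move; and the NuqStream case is an explicit specialization whose definition stays in weights.cc (whether the header also declares that specialization is not visible in this diff). A toy version of the pattern with hypothetical names (Holder, Packed):

#include <type_traits>

struct Packed {};  // stands in for a compressed type such as NuqStream

template <class T>
struct Holder {
  // Generic path: inline in the header, so every TU can instantiate it.
  void Reshape() {
    static_assert(!std::is_same_v<T, Packed>,
                  "Packed must use the specialization defined in holder.cc");
    // ... element-wise copy, as in the weights.h body above ...
  }
};

// Declared here for other TUs; defined once in holder.cc, where it can
// decompress, reshape, and re-compress (mirroring gemma/weights.cc above).
template <>
void Holder<Packed>::Reshape();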
