Skip to content

Commit

Permalink
finished vision model
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkoski committed Nov 18, 2024
1 parent 01dc818 commit e19f736
Showing 1 changed file with 131 additions and 116 deletions.
247 changes: 131 additions & 116 deletions Libraries/VLM/Models/Qwen2VL.swift
Original file line number Diff line number Diff line change
Expand Up @@ -460,140 +460,152 @@ private enum Vision {
}
}

fileprivate class PhiMLP: Module, UnaryLayer {
fileprivate class MLP: Module, UnaryLayer {

@ModuleInfo var activation: GELU
@ModuleInfo var fc1: Linear
@ModuleInfo var fc2: Linear

public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
self.fc1 = Linear(config.hiddenSize, config.intermediateSize, bias: true)
self.fc2 = Linear(config.intermediateSize, config.hiddenSize, bias: true)
public init(dimensions: Int, hiddenDimensions: Int) {
self.activation = GELU(approximation: .fast)
self.fc1 = Linear(dimensions, hiddenDimensions)
self.fc2 = Linear(hiddenDimensions, dimensions)
}

public func callAsFunction(_ x: MLXArray) -> MLXArray {
fc2(geluApproximate(fc1(x)))
}
}

fileprivate class EncoderLayer: Module {

@ModuleInfo(key: "self_attn") var attention: Attention
@ModuleInfo(key: "layer_norm1") var layerNorm1: LayerNorm
@ModuleInfo var mlp: PhiMLP
@ModuleInfo(key: "layer_norm2") var layerNorm2: LayerNorm

public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
self._attention.wrappedValue = Attention(
dims: config.hiddenSize, numHeads: config.attentionHeads, bias: true)
self._layerNorm1.wrappedValue = LayerNorm(
dimensions: config.hiddenSize, eps: config.layerNormEps)
self.mlp = PhiMLP(config)
self._layerNorm2.wrappedValue = LayerNorm(
dimensions: config.hiddenSize, eps: config.layerNormEps)
}

public func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil) -> MLXArray {
var r = attention(layerNorm1(x), mask: mask)
let h = x + r
r = mlp(layerNorm2(h))
return h + r
fc2(activation(fc1(x)))
}
}

fileprivate class Encoder: Module {
var layers: [EncoderLayer]


fileprivate class Qwen2VLVisionBlock: Module {

@ModuleInfo var norm1: LayerNorm
@ModuleInfo var norm2: LayerNorm
@ModuleInfo(key: "attn") var attention: Attention
@ModuleInfo var mlp: MLP

public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
self.layers = (0 ..< config.hiddenLayers).map { _ in
EncoderLayer(config)
}
}

public func callAsFunction(
_ x: MLXArray, outputHiddenStates: Bool = false, mask: MLXArray? = nil
) -> (MLXArray, [MLXArray]?) {
var encoderStates: [MLXArray]? = outputHiddenStates ? [] : nil
var h = x
var x = x
for l in layers {
x = l(x, mask: mask)
if outputHiddenStates {
encoderStates?.append(x)
}
h = x[0]
}
return (h, encoderStates)
self.norm1 = LayerNorm(dimensions: config.embedDimensions, eps: 1e-6)
self.norm2 = LayerNorm(dimensions: config.embedDimensions, eps: 1e-6)

self._attention.wrappedValue = Attention(dims: config.embedDimensions, numHeads: config.numHeads)

let mlpHiddenDimensions = Int(Float(config.embedDimensions) * config.mlpRatio)
self.mlp = MLP(dimensions: config.embedDimensions, hiddenDimensions: mlpHiddenDimensions)
}
}

fileprivate class VisionEmbeddings: Module, UnaryLayer {

@ModuleInfo(key: "patch_embedding") var patchEmbedding: Conv2d
@ModuleInfo(key: "position_embedding") var positionEmbedding: Embedding

let positions: Int
let positionIds: MLXArray

public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
self._patchEmbedding.wrappedValue = Conv2d(
inputChannels: config.channels, outputChannels: config.hiddenSize,
kernelSize: .init(config.patchSize), stride: .init(config.patchSize)
)
let d = config.imageSize / config.patchSize
self.positions = d * d
self._positionEmbedding.wrappedValue = Embedding(
embeddingCount: positions, dimensions: config.hiddenSize

func callAsFunction(_ hiddenStates: MLXArray, cuSequenceLengths: MLXArray, rotaryPositionEmbedding: MLXArray) -> MLXArray {
var hiddenStates = hiddenStates + attention(
norm1(hiddenStates),
cuSequenceLengths: cuSequenceLengths,
rotaryPositionEmbedding: rotaryPositionEmbedding
)
self.positionIds = MLXArray(0 ..< positions)[.newAxis, 0...]
}

public func callAsFunction(_ x: MLXArray) -> MLXArray {
var patchEmbeddings = self.patchEmbedding(x)
patchEmbeddings = patchEmbeddings.flattened(start: 1, end: 2)
let embeddings = patchEmbeddings + self.positionEmbedding(self.positionIds)
return embeddings
}
}

fileprivate class SigLipVisionModel: Module {

@ModuleInfo var embeddings: VisionEmbeddings
@ModuleInfo var encoder: Encoder
@ModuleInfo(key: "post_layernorm") var postLayerNorm: LayerNorm

public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
self.embeddings = VisionEmbeddings(config)
self.encoder = Encoder(config)
self._postLayerNorm.wrappedValue = LayerNorm(dimensions: config.hiddenSize)
}

public func callAsFunction(_ x: MLXArray, outputHiddenStates: Bool = false) -> (
MLXArray, MLXArray, MLXArray?
) {
let x = embeddings(x)

let (encoderOutput, hiddenStates) = encoder(x, outputHiddenStates: outputHiddenStates)
let poolerOutput = postLayerNorm(encoderOutput)

return (poolerOutput, x, hiddenStates?.last)
hiddenStates = hiddenStates + mlp(norm2(hiddenStates))
return hiddenStates
}
}

fileprivate class VisionModel: Module {

@ModuleInfo(key: "vision_model") var visionModel: SigLipVisionModel
@ModuleInfo(key: "patch_embed") var patchEmbed: PatchEmbed
@ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: VisionRotaryEmbedding
@ModuleInfo(key: "blocks") var blocks: [Qwen2VLVisionBlock]
@ModuleInfo(key: "merger") var patchMerger: PatchMerger

let spatialMergeSize: Int

public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
precondition(
config.modelType == "siglip_vision_model",
config.modelType == "qwen2_vl",
"Unsupported modelType: \(config.modelType)")
self._visionModel.wrappedValue = SigLipVisionModel(config)

self.spatialMergeSize = config.spatialMergeSize

self._patchEmbed.wrappedValue = PatchEmbed(
patchSize: config.patchSize,
temporalPatchSize: config.temporalPatchSize,
inChannels: config.inChannels,
embedDimensions: config.embedDimensions)

let headDimensions = config.embedDimensions / config.numHeads
self._rotaryPositionEmbedding.wrappedValue = VisionRotaryEmbedding(dimensions: headDimensions, theta: 10_000)

self._blocks.wrappedValue = (0 ..< config.depth).map { _ in
Qwen2VLVisionBlock(config)
}
self.patchMerger = PatchMerger(dimensions: config.hiddenSize, contextDimensions: config.embedDimensions, spatialMergeSize: 2)
}

func rotaryPositionEmbedding(_ gridThw: MLXArray) -> MLXArray {
var positionIds = [MLXArray]()

for row in gridThw {
// TODO NOTE: this evaluates gridThw -- it shouldn't do that
let t = row[0].item(Int.self)
let h = row[1].item(Int.self)
let w = row[2].item(Int.self)

var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1)
hposIds = repeated(hposIds, count: w, axis: 1)
hposIds = hposIds
.reshaped(
h / spatialMergeSize,
spatialMergeSize,
w / spatialMergeSize,
spatialMergeSize)
.transposed(0, 2, 1, 3)
.flattened()

var wposIds = expandedDimensions(MLXArray(0 ..< w), axis: 0)
wposIds = repeated(wposIds, count: h, axis: 0)
wposIds = hposIds
.reshaped(
h / spatialMergeSize,
spatialMergeSize,
w / spatialMergeSize,
spatialMergeSize)
.transposed(0, 2, 1, 3)
.flattened()

let stackedPosIds = stacked([hposIds, wposIds], axis: -1)
positionIds.append(repeated(stackedPosIds, count: t, axis: 0))
}

let indices = concatenated(positionIds, axis: 0)
let maxGridSize = max(gridThw[0..., 1...])
let rotaryPositionEmbedFull = rotaryPositionEmbedding(maxGridSize)[indices]

return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1)
}

public func callAsFunction(_ x: MLXArray, outputHiddenStates: Bool = false) -> (
MLXArray, MLXArray, MLXArray?
) {
visionModel(x, outputHiddenStates: outputHiddenStates)
public func callAsFunction(_ hiddenStates: MLXArray, gridThw: MLXArray) -> MLXArray {
var hiddenStates = patchEmbed(hiddenStates)
let rotaryPositionEmbedding = rotaryPositionEmbedding(gridThw)

// Assuming grid_thw has shape (batch_size, 3)
let batchSize = gridThw.dim(0)

// Calculate cu_seqlens for each item in the batch
var collect = [MLXArray]()
for i in 0 ..< batchSize {
let sequenceLength = gridThw[i, 1] * gridThw[i, 2]

// TODO NOTE: this evaluates gridThw -- it shouldn't do that
let t = gridThw[i, 0].item(Int.self)
collect.append(repeated(sequenceLength, count: t))
}

// Concatenate the cu_seqlens for all items in the batch
var cuSeqLengths = concatenated(collect)

cuSeqLengths = cumsum(cuSeqLengths.asType(Int32.self), axis: 0)
cuSeqLengths = padded(cuSeqLengths, width: [1, 0], mode: .constant, value: MLXArray(0))

for block in blocks {
hiddenStates = block(hiddenStates, cuSequenceLengths: cuSeqLengths, rotaryPositionEmbedding: rotaryPositionEmbedding)
}

return patchMerger(hiddenStates)
}

private func isMLXWeight(_ array: MLXArray) -> Bool {
Expand All @@ -616,15 +628,18 @@ private enum Vision {
if k.contains("position_id") {
// Remove unused position_ids
continue
} else if k.contains("patch_embedding.weight") {
} else if k.contains("patch_embed.proj.weight") {
// TODO: this comment doesn't match -- based on above code I presume
// the first dimension is now B

// PyTorch conv2d weight tensors have shape:
// [out_channels, in_channels, kH, KW]
// MLX conv2d expects the weight be of shape:
// [out_channels, kH, KW, in_channels]
if isMLXWeight(v) {
sanitizedWeights[k] = v
} else {
sanitizedWeights[k] = v.transposed(0, 2, 3, 1)
sanitizedWeights[k] = v.transposed(0, 2, 3, 4, 1)
}
} else {
sanitizedWeights[k] = v
Expand Down Expand Up @@ -882,8 +897,8 @@ public struct Qwen2VLConfiguration: Codable, Sendable {
public let patchSize: Int
public let vocabularySize: Int
public let mlpRatio: Float
public let _channels: Int?
public var channels: Int { _channels ?? 3 }
public let _inChannels: Int?
public var inChannels: Int { _inChannels ?? 3 }
public let _layerNormEps: Float?
public var layerNormEps: Float { _layerNormEps ?? 1e-6 }
public let spatialPatchSize: Int
Expand All @@ -900,7 +915,7 @@ public struct Qwen2VLConfiguration: Codable, Sendable {
case patchSize = "patch_size"
case vocabularySize = "vocab_size"
case mlpRatio = "mlp_ratio"
case _channels = "num_channels"
case _inChannels = "in_channels"
case _layerNormEps = "layer_norm_eps"
case spatialPatchSize = "spatial_patch_size"
case spatialMergeSize = "spatial_merge_size"
Expand Down

0 comments on commit e19f736

Please sign in to comment.