From 2d614f33c33483cf25cc5ac786199465f2ff4e44 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Tue, 5 Sep 2023 10:01:43 +0800 Subject: [PATCH] feat: add native source to git (#118) --- cpp/.gitignore | 6 - cpp/coreml/whisper-decoder-impl.h | 146 + cpp/coreml/whisper-decoder-impl.m | 201 + cpp/coreml/whisper-encoder-impl.h | 142 + cpp/coreml/whisper-encoder-impl.m | 197 + cpp/coreml/whisper-encoder.h | 22 + cpp/coreml/whisper-encoder.mm | 65 + cpp/ggml.c | 18740 ++++++++++++++++++++++++++++ cpp/ggml.h | 1541 +++ cpp/whisper.cpp | 5512 ++++++++ cpp/whisper.h | 531 + 11 files changed, 27097 insertions(+), 6 deletions(-) delete mode 100644 cpp/.gitignore create mode 100644 cpp/coreml/whisper-decoder-impl.h create mode 100644 cpp/coreml/whisper-decoder-impl.m create mode 100644 cpp/coreml/whisper-encoder-impl.h create mode 100644 cpp/coreml/whisper-encoder-impl.m create mode 100644 cpp/coreml/whisper-encoder.h create mode 100644 cpp/coreml/whisper-encoder.mm create mode 100644 cpp/ggml.c create mode 100644 cpp/ggml.h create mode 100644 cpp/whisper.cpp create mode 100644 cpp/whisper.h diff --git a/cpp/.gitignore b/cpp/.gitignore deleted file mode 100644 index 4f7c8ae..0000000 --- a/cpp/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -*.c -*.h -*.cpp -*.m -*.mm -!rn-whisper* diff --git a/cpp/coreml/whisper-decoder-impl.h b/cpp/coreml/whisper-decoder-impl.h new file mode 100644 index 0000000..c6f2e85 --- /dev/null +++ b/cpp/coreml/whisper-decoder-impl.h @@ -0,0 +1,146 @@ +// +// whisper-decoder-impl.h +// +// This file was automatically generated and should not be edited. +// + +#import +#import +#include +#include + +NS_ASSUME_NONNULL_BEGIN + + +/// Model Prediction Input Type +API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden"))) +@interface whisper_decoder_implInput : NSObject + +/// token_data as 1 by 1 matrix of 32-bit integers +@property (readwrite, nonatomic, strong) MLMultiArray * token_data; + +/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats +@property (readwrite, nonatomic, strong) MLMultiArray * audio_data; +- (instancetype)init NS_UNAVAILABLE; +- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER; + +@end + + +/// Model Prediction Output Type +API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden"))) +@interface whisper_decoder_implOutput : NSObject + +/// var_1346 as multidimensional array of floats +@property (readwrite, nonatomic, strong) MLMultiArray * var_1346; +- (instancetype)init NS_UNAVAILABLE; +- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER; + +@end + + +/// Class for model loading and prediction +API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden"))) +@interface whisper_decoder_impl : NSObject +@property (readonly, nonatomic, nullable) MLModel * model; + +/** + URL of the underlying .mlmodelc directory. +*/ ++ (nullable NSURL *)URLOfModelInThisBundle; + +/** + Initialize whisper_decoder_impl instance from an existing MLModel object. + + Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl. + Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in. 
+*/ +- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER; + +/** + Initialize whisper_decoder_impl instance with the model in this bundle. +*/ +- (nullable instancetype)init; + +/** + Initialize whisper_decoder_impl instance with the model in this bundle. + + @param configuration The model configuration object + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. +*/ +- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error; + +/** + Initialize whisper_decoder_impl instance from the model URL. + + @param modelURL URL to the .mlmodelc directory for whisper_decoder_impl. + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. +*/ +- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error; + +/** + Initialize whisper_decoder_impl instance from the model URL. + + @param modelURL URL to the .mlmodelc directory for whisper_decoder_impl. + @param configuration The model configuration object + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. +*/ +- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error; + +/** + Construct whisper_decoder_impl instance asynchronously with configuration. + Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread. + + @param configuration The model configuration + @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object. +*/ ++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler; + +/** + Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration. + + Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread. + + @param modelURL The model URL. + @param configuration The model configuration + @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object. +*/ ++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler; + +/** + Make a prediction using the standard interface + @param input an instance of whisper_decoder_implInput to predict from + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. 
+ @return the prediction as whisper_decoder_implOutput +*/ +- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error; + +/** + Make a prediction using the standard interface + @param input an instance of whisper_decoder_implInput to predict from + @param options prediction options + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. + @return the prediction as whisper_decoder_implOutput +*/ +- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error; + +/** + Make a prediction using the convenience interface + @param token_data as 1 by 1 matrix of 32-bit integers: + @param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats: + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. + @return the prediction as whisper_decoder_implOutput +*/ +- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error; + +/** + Batch prediction + @param inputArray array of whisper_decoder_implInput instances to obtain predictions from + @param options prediction options + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. + @return the predictions as NSArray +*/ +- (nullable NSArray *)predictionsFromInputs:(NSArray *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error; +@end + +NS_ASSUME_NONNULL_END diff --git a/cpp/coreml/whisper-decoder-impl.m b/cpp/coreml/whisper-decoder-impl.m new file mode 100644 index 0000000..34060e4 --- /dev/null +++ b/cpp/coreml/whisper-decoder-impl.m @@ -0,0 +1,201 @@ +// +// whisper-decoder-impl.m +// +// This file was automatically generated and should not be edited. 
+// + +#if !__has_feature(objc_arc) +#error This file must be compiled with automatic reference counting enabled (-fobjc-arc) +#endif + +#import "whisper-decoder-impl.h" + +@implementation whisper_decoder_implInput + +- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data { + self = [super init]; + if (self) { + _token_data = token_data; + _audio_data = audio_data; + } + return self; +} + +- (NSSet *)featureNames { + return [NSSet setWithArray:@[@"token_data", @"audio_data"]]; +} + +- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName { + if ([featureName isEqualToString:@"token_data"]) { + return [MLFeatureValue featureValueWithMultiArray:self.token_data]; + } + if ([featureName isEqualToString:@"audio_data"]) { + return [MLFeatureValue featureValueWithMultiArray:self.audio_data]; + } + return nil; +} + +@end + +@implementation whisper_decoder_implOutput + +- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 { + self = [super init]; + if (self) { + _var_1346 = var_1346; + } + return self; +} + +- (NSSet *)featureNames { + return [NSSet setWithArray:@[@"var_1346"]]; +} + +- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName { + if ([featureName isEqualToString:@"var_1346"]) { + return [MLFeatureValue featureValueWithMultiArray:self.var_1346]; + } + return nil; +} + +@end + +@implementation whisper_decoder_impl + + +/** + URL of the underlying .mlmodelc directory. +*/ ++ (nullable NSURL *)URLOfModelInThisBundle { + NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"whisper_decoder_impl" ofType:@"mlmodelc"]; + if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load whisper-decoder-impl.mlmodelc in the bundle resource"); return nil; } + return [NSURL fileURLWithPath:assetPath]; +} + + +/** + Initialize whisper_decoder_impl instance from an existing MLModel object. + + Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl. + Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in. +*/ +- (instancetype)initWithMLModel:(MLModel *)model { + self = [super init]; + if (!self) { return nil; } + _model = model; + if (_model == nil) { return nil; } + return self; +} + + +/** + Initialize whisper_decoder_impl instance with the model in this bundle. +*/ +- (nullable instancetype)init { + return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil]; +} + + +/** + Initialize whisper_decoder_impl instance with the model in this bundle. + + @param configuration The model configuration object + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. +*/ +- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error { + return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error]; +} + + +/** + Initialize whisper_decoder_impl instance from the model URL. + + @param modelURL URL to the .mlmodelc directory for whisper_decoder_impl. + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. 
+*/ +- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error { + MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error]; + if (model == nil) { return nil; } + return [self initWithMLModel:model]; +} + + +/** + Initialize whisper_decoder_impl instance from the model URL. + + @param modelURL URL to the .mlmodelc directory for whisper_decoder_impl. + @param configuration The model configuration object + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. +*/ +- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error { + MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error]; + if (model == nil) { return nil; } + return [self initWithMLModel:model]; +} + + +/** + Construct whisper_decoder_impl instance asynchronously with configuration. + Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread. + + @param configuration The model configuration + @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object. +*/ ++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler { + [self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle] + configuration:configuration + completionHandler:handler]; +} + + +/** + Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration. + + Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread. + + @param modelURL The model URL. + @param configuration The model configuration + @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object. 
+*/ ++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler { + [MLModel loadContentsOfURL:modelURL + configuration:configuration + completionHandler:^(MLModel *model, NSError *error) { + if (model != nil) { + whisper_decoder_impl *typedModel = [[whisper_decoder_impl alloc] initWithMLModel:model]; + handler(typedModel, nil); + } else { + handler(nil, error); + } + }]; +} + +- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error { + return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error]; +} + +- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error { + id outFeatures = [self.model predictionFromFeatures:input options:options error:error]; + if (!outFeatures) { return nil; } + return [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[outFeatures featureValueForName:@"var_1346"].multiArrayValue]; +} + +- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error { + whisper_decoder_implInput *input_ = [[whisper_decoder_implInput alloc] initWithToken_data:token_data audio_data:audio_data]; + return [self predictionFromFeatures:input_ error:error]; +} + +- (nullable NSArray *)predictionsFromInputs:(NSArray *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error { + id inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray]; + id outBatch = [self.model predictionsFromBatch:inBatch options:options error:error]; + if (!outBatch) { return nil; } + NSMutableArray *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count]; + for (NSInteger i = 0; i < outBatch.count; i++) { + id resultProvider = [outBatch featuresAtIndex:i]; + whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[resultProvider featureValueForName:@"var_1346"].multiArrayValue]; + [results addObject:result]; + } + return results; +} + +@end diff --git a/cpp/coreml/whisper-encoder-impl.h b/cpp/coreml/whisper-encoder-impl.h new file mode 100644 index 0000000..ecb6155 --- /dev/null +++ b/cpp/coreml/whisper-encoder-impl.h @@ -0,0 +1,142 @@ +// +// whisper-encoder-impl.h +// +// This file was automatically generated and should not be edited. 
+// + +#import +#import +#include +#include + +NS_ASSUME_NONNULL_BEGIN + + +/// Model Prediction Input Type +API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden"))) +@interface whisper_encoder_implInput : NSObject + +/// logmel_data as 1 × 80 × 3000 3-dimensional array of floats +@property (readwrite, nonatomic, strong) MLMultiArray * logmel_data; +- (instancetype)init NS_UNAVAILABLE; +- (instancetype)initWithLogmel_data:(MLMultiArray *)logmel_data NS_DESIGNATED_INITIALIZER; + +@end + + +/// Model Prediction Output Type +API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden"))) +@interface whisper_encoder_implOutput : NSObject + +/// output as multidimensional array of floats +@property (readwrite, nonatomic, strong) MLMultiArray * output; +- (instancetype)init NS_UNAVAILABLE; +- (instancetype)initWithOutput:(MLMultiArray *)output NS_DESIGNATED_INITIALIZER; + +@end + + +/// Class for model loading and prediction +API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden"))) +@interface whisper_encoder_impl : NSObject +@property (readonly, nonatomic, nullable) MLModel * model; + +/** + URL of the underlying .mlmodelc directory. +*/ ++ (nullable NSURL *)URLOfModelInThisBundle; + +/** + Initialize whisper_encoder_impl instance from an existing MLModel object. + + Usually the application does not use this initializer unless it makes a subclass of whisper_encoder_impl. + Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in. +*/ +- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER; + +/** + Initialize whisper_encoder_impl instance with the model in this bundle. +*/ +- (nullable instancetype)init; + +/** + Initialize whisper_encoder_impl instance with the model in this bundle. + + @param configuration The model configuration object + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. +*/ +- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error; + +/** + Initialize whisper_encoder_impl instance from the model URL. + + @param modelURL URL to the .mlmodelc directory for whisper_encoder_impl. + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. +*/ +- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error; + +/** + Initialize whisper_encoder_impl instance from the model URL. + + @param modelURL URL to the .mlmodelc directory for whisper_encoder_impl. + @param configuration The model configuration object + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. +*/ +- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error; + +/** + Construct whisper_encoder_impl instance asynchronously with configuration. + Model loading may take time when the model content is not immediately available (e.g. encrypted model). 
Use this factory method especially when the caller is on the main thread. + + @param configuration The model configuration + @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object. +*/ ++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler; + +/** + Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration. + + Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread. + + @param modelURL The model URL. + @param configuration The model configuration + @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object. +*/ ++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler; + +/** + Make a prediction using the standard interface + @param input an instance of whisper_encoder_implInput to predict from + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. + @return the prediction as whisper_encoder_implOutput +*/ +- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error; + +/** + Make a prediction using the standard interface + @param input an instance of whisper_encoder_implInput to predict from + @param options prediction options + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. + @return the prediction as whisper_encoder_implOutput +*/ +- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error; + +/** + Make a prediction using the convenience interface + @param logmel_data as 1 × 80 × 3000 3-dimensional array of floats: + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. + @return the prediction as whisper_encoder_implOutput +*/ +- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error; + +/** + Batch prediction + @param inputArray array of whisper_encoder_implInput instances to obtain predictions from + @param options prediction options + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. 
+ @return the predictions as NSArray +*/ +- (nullable NSArray *)predictionsFromInputs:(NSArray *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error; +@end + +NS_ASSUME_NONNULL_END diff --git a/cpp/coreml/whisper-encoder-impl.m b/cpp/coreml/whisper-encoder-impl.m new file mode 100644 index 0000000..ee8e506 --- /dev/null +++ b/cpp/coreml/whisper-encoder-impl.m @@ -0,0 +1,197 @@ +// +// whisper-encoder-impl.m +// +// This file was automatically generated and should not be edited. +// + +#if !__has_feature(objc_arc) +#error This file must be compiled with automatic reference counting enabled (-fobjc-arc) +#endif + +#import "whisper-encoder-impl.h" + +@implementation whisper_encoder_implInput + +- (instancetype)initWithLogmel_data:(MLMultiArray *)logmel_data { + self = [super init]; + if (self) { + _logmel_data = logmel_data; + } + return self; +} + +- (NSSet *)featureNames { + return [NSSet setWithArray:@[@"logmel_data"]]; +} + +- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName { + if ([featureName isEqualToString:@"logmel_data"]) { + return [MLFeatureValue featureValueWithMultiArray:self.logmel_data]; + } + return nil; +} + +@end + +@implementation whisper_encoder_implOutput + +- (instancetype)initWithOutput:(MLMultiArray *)output { + self = [super init]; + if (self) { + _output = output; + } + return self; +} + +- (NSSet *)featureNames { + return [NSSet setWithArray:@[@"output"]]; +} + +- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName { + if ([featureName isEqualToString:@"output"]) { + return [MLFeatureValue featureValueWithMultiArray:self.output]; + } + return nil; +} + +@end + +@implementation whisper_encoder_impl + + +/** + URL of the underlying .mlmodelc directory. +*/ ++ (nullable NSURL *)URLOfModelInThisBundle { + NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"whisper_encoder_impl" ofType:@"mlmodelc"]; + if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load whisper-encoder-impl.mlmodelc in the bundle resource"); return nil; } + return [NSURL fileURLWithPath:assetPath]; +} + + +/** + Initialize whisper_encoder_impl instance from an existing MLModel object. + + Usually the application does not use this initializer unless it makes a subclass of whisper_encoder_impl. + Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in. +*/ +- (instancetype)initWithMLModel:(MLModel *)model { + self = [super init]; + if (!self) { return nil; } + _model = model; + if (_model == nil) { return nil; } + return self; +} + + +/** + Initialize whisper_encoder_impl instance with the model in this bundle. +*/ +- (nullable instancetype)init { + return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil]; +} + + +/** + Initialize whisper_encoder_impl instance with the model in this bundle. + + @param configuration The model configuration object + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. 
+*/ +- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error { + return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error]; +} + + +/** + Initialize whisper_encoder_impl instance from the model URL. + + @param modelURL URL to the .mlmodelc directory for whisper_encoder_impl. + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. +*/ +- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error { + MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error]; + if (model == nil) { return nil; } + return [self initWithMLModel:model]; +} + + +/** + Initialize whisper_encoder_impl instance from the model URL. + + @param modelURL URL to the .mlmodelc directory for whisper_encoder_impl. + @param configuration The model configuration object + @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL. +*/ +- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error { + MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error]; + if (model == nil) { return nil; } + return [self initWithMLModel:model]; +} + + +/** + Construct whisper_encoder_impl instance asynchronously with configuration. + Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread. + + @param configuration The model configuration + @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object. +*/ ++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler { + [self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle] + configuration:configuration + completionHandler:handler]; +} + + +/** + Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration. + + Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread. + + @param modelURL The model URL. + @param configuration The model configuration + @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object. 
+*/ ++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler { + [MLModel loadContentsOfURL:modelURL + configuration:configuration + completionHandler:^(MLModel *model, NSError *error) { + if (model != nil) { + whisper_encoder_impl *typedModel = [[whisper_encoder_impl alloc] initWithMLModel:model]; + handler(typedModel, nil); + } else { + handler(nil, error); + } + }]; +} + +- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error { + return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error]; +} + +- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error { + id outFeatures = [self.model predictionFromFeatures:input options:options error:error]; + if (!outFeatures) { return nil; } + return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue]; +} + +- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error { + whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data]; + return [self predictionFromFeatures:input_ error:error]; +} + +- (nullable NSArray *)predictionsFromInputs:(NSArray *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error { + id inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray]; + id outBatch = [self.model predictionsFromBatch:inBatch options:options error:error]; + if (!outBatch) { return nil; } + NSMutableArray *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count]; + for (NSInteger i = 0; i < outBatch.count; i++) { + id resultProvider = [outBatch featuresAtIndex:i]; + whisper_encoder_implOutput * result = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[resultProvider featureValueForName:@"output"].multiArrayValue]; + [results addObject:result]; + } + return results; +} + +@end diff --git a/cpp/coreml/whisper-encoder.h b/cpp/coreml/whisper-encoder.h new file mode 100644 index 0000000..84bbe41 --- /dev/null +++ b/cpp/coreml/whisper-encoder.h @@ -0,0 +1,22 @@ +// Wrapper of the Core ML Whisper Encoder model +// +// Code is derived from the work of Github user @wangchou +// ref: https://github.com/wangchou/callCoreMLFromCpp + +#if __cplusplus +extern "C" { +#endif + +struct whisper_coreml_context; + +struct whisper_coreml_context * whisper_coreml_init(const char * path_model); +void whisper_coreml_free(struct whisper_coreml_context * ctx); + +void whisper_coreml_encode( + const whisper_coreml_context * ctx, + float * mel, + float * out); + +#if __cplusplus +} +#endif diff --git a/cpp/coreml/whisper-encoder.mm b/cpp/coreml/whisper-encoder.mm new file mode 100644 index 0000000..6cd90ed --- /dev/null +++ b/cpp/coreml/whisper-encoder.mm @@ -0,0 +1,65 @@ +#if !__has_feature(objc_arc) +#error This file must be compiled with automatic reference counting enabled (-fobjc-arc) +#endif + +#import "whisper-encoder.h" +#import "whisper-encoder-impl.h" + +#import + +#include + +#if __cplusplus +extern "C" { +#endif + +struct 
whisper_coreml_context { + const void * data; +}; + +struct whisper_coreml_context * whisper_coreml_init(const char * path_model) { + NSString * path_model_str = [[NSString alloc] initWithUTF8String:path_model]; + + NSURL * url_model = [NSURL fileURLWithPath: path_model_str]; + + const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model error:nil]); + + if (data == NULL) { + return NULL; + } + + whisper_coreml_context * ctx = new whisper_coreml_context; + + ctx->data = data; + + return ctx; +} + +void whisper_coreml_free(struct whisper_coreml_context * ctx) { + CFRelease(ctx->data); + delete ctx; +} + +void whisper_coreml_encode( + const whisper_coreml_context * ctx, + float * mel, + float * out) { + MLMultiArray * inMultiArray = [ + [MLMultiArray alloc] initWithDataPointer: mel + shape: @[@1, @80, @3000] + dataType: MLMultiArrayDataTypeFloat32 + strides: @[@(240000), @(3000), @1] + deallocator: nil + error: nil + ]; + + @autoreleasepool { + whisper_encoder_implOutput * outCoreML = [(__bridge id) ctx->data predictionFromLogmel_data:inMultiArray error:nil]; + + memcpy(out, outCoreML.output.dataPointer, outCoreML.output.count * sizeof(float)); + } +} + +#if __cplusplus +} +#endif diff --git a/cpp/ggml.c b/cpp/ggml.c new file mode 100644 index 0000000..4e2695d --- /dev/null +++ b/cpp/ggml.c @@ -0,0 +1,18740 @@ +#define _GNU_SOURCE // Defines CLOCK_MONOTONIC on Linux +#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows + +#include "ggml.h" + +#ifdef WSP_GGML_USE_K_QUANTS +#include "k_quants.h" +#endif + +#if defined(_MSC_VER) || defined(__MINGW32__) +#include // using malloc.h with MSC/MINGW +#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef WSP_GGML_USE_METAL +#include +#endif + +// if C99 - static_assert is noop +// ref: https://stackoverflow.com/a/53923785/4039976 +#ifndef static_assert +#define static_assert(cond, msg) struct global_scope_noop_trick +#endif + +#if defined(_MSC_VER) +// disable "possible loss of data" to avoid hundreds of casts +// we should just be careful :) +#pragma warning(disable: 4244 4267) +#endif + +#if defined(_WIN32) + +#include + +typedef volatile LONG atomic_int; +typedef atomic_int atomic_bool; + +static void atomic_store(atomic_int* ptr, LONG val) { + InterlockedExchange(ptr, val); +} +static LONG atomic_load(atomic_int* ptr) { + return InterlockedCompareExchange(ptr, 0, 0); +} +static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) { + return InterlockedExchangeAdd(ptr, inc); +} +static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) { + return atomic_fetch_add(ptr, -(dec)); +} + +typedef HANDLE pthread_t; + +typedef DWORD thread_ret_t; +static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { + (void) unused; + HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); + if (handle == NULL) + { + return EAGAIN; + } + + *out = handle; + return 0; +} + +static int pthread_join(pthread_t thread, void* unused) { + (void) unused; + return (int) WaitForSingleObject(thread, INFINITE); +} + +static int sched_yield (void) { + Sleep (0); + return 0; +} +#else +#include +#include + +typedef void* thread_ret_t; + +#include +#include +#include + +#endif + +// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 +#if 
defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) +#ifndef __FMA__ +#define __FMA__ +#endif +#ifndef __F16C__ +#define __F16C__ +#endif +#ifndef __SSE3__ +#define __SSE3__ +#endif +#endif + +#ifdef __HAIKU__ +#define static_assert(cond, msg) _Static_assert(cond, msg) +#endif + +/*#define WSP_GGML_PERF*/ +#define WSP_GGML_DEBUG 0 +#define WSP_GGML_GELU_FP16 +#define WSP_GGML_GELU_QUICK_FP16 +#define WSP_GGML_SILU_FP16 + +#define WSP_GGML_SOFT_MAX_UNROLL 4 +#define WSP_GGML_VEC_DOT_UNROLL 2 + +// +// logging +// + +#if (WSP_GGML_DEBUG >= 1) +#define WSP_GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#else +#define WSP_GGML_PRINT_DEBUG(...) +#endif + +#if (WSP_GGML_DEBUG >= 5) +#define WSP_GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) +#else +#define WSP_GGML_PRINT_DEBUG_5(...) +#endif + +#if (WSP_GGML_DEBUG >= 10) +#define WSP_GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) +#else +#define WSP_GGML_PRINT_DEBUG_10(...) +#endif + +#define WSP_GGML_PRINT(...) printf(__VA_ARGS__) + +#ifdef WSP_GGML_USE_ACCELERATE +// uncomment to use vDSP for soft max computation +// note: not sure if it is actually faster +//#define WSP_GGML_SOFT_MAX_ACCELERATE +#endif + +#if UINTPTR_MAX == 0xFFFFFFFF + #define WSP_GGML_MEM_ALIGN 4 +#else + #define WSP_GGML_MEM_ALIGN 16 +#endif + +// +// logging +// + +#if (WSP_GGML_DEBUG >= 1) +#define WSP_GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#else +#define WSP_GGML_PRINT_DEBUG(...) +#endif + +#if (WSP_GGML_DEBUG >= 5) +#define WSP_GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) +#else +#define WSP_GGML_PRINT_DEBUG_5(...) +#endif + +#if (WSP_GGML_DEBUG >= 10) +#define WSP_GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) +#else +#define WSP_GGML_PRINT_DEBUG_10(...) +#endif + +#define WSP_GGML_PRINT(...) printf(__VA_ARGS__) + +// +// end of logging block +// + +#if defined(_MSC_VER) || defined(__MINGW32__) +#define WSP_GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, WSP_GGML_MEM_ALIGN) +#define WSP_GGML_ALIGNED_FREE(ptr) _aligned_free(ptr) +#else +inline static void* wsp_ggml_aligned_malloc(size_t size) { + void* aligned_memory = NULL; +#ifdef WSP_GGML_USE_METAL + int result = posix_memalign(&aligned_memory, getpagesize(), size); +#else + int result = posix_memalign(&aligned_memory, WSP_GGML_MEM_ALIGN, size); +#endif + if (result != 0) { + // Handle allocation failure + const char *error_desc = "unknown allocation error"; + switch (result) { + case EINVAL: + error_desc = "invalid alignment value"; + break; + case ENOMEM: + error_desc = "insufficient memory"; + break; + } + WSP_GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", + __func__, error_desc, size/(1024.0*1024.0)); + return NULL; + } + return aligned_memory; +} +#define WSP_GGML_ALIGNED_MALLOC(size) wsp_ggml_aligned_malloc(size) +#define WSP_GGML_ALIGNED_FREE(ptr) free(ptr) +#endif + +#define UNUSED WSP_GGML_UNUSED +#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) + +// +// tensor access macros +// + +#define WSP_GGML_TENSOR_UNARY_OP_LOCALS \ + WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ + WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \ + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \ + WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + +#define WSP_GGML_TENSOR_BINARY_OP_LOCALS \ + WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ + WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \ + WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \ + WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); \ + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \ + WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + +#if 
defined(WSP_GGML_USE_ACCELERATE) +#include +#if defined(WSP_GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions +#include "ggml-opencl.h" +#endif +#elif defined(WSP_GGML_USE_OPENBLAS) +#include +#elif defined(WSP_GGML_USE_CUBLAS) +#include "ggml-cuda.h" +#elif defined(WSP_GGML_USE_CLBLAST) +#include "ggml-opencl.h" +#endif + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// floating point type used to accumulate sums +typedef double wsp_ggml_float; + +// 16-bit float +// on Arm, we use __fp16 +// on x86, we use uint16_t +#ifdef __ARM_NEON + +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include + +#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) +#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) (x) + +#define WSP_GGML_FP16_TO_FP32(x) ((float) (x)) +#define WSP_GGML_FP32_TO_FP16(x) (x) + +#else + +#ifdef __wasm_simd128__ +#include +#else +#ifdef __POWER9_VECTOR__ +#include +#undef bool +#define bool _Bool +#else +#if defined(_MSC_VER) || defined(__MINGW32__) +#include +#else +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) +#include +#endif +#endif +#endif +#endif + +#ifdef __F16C__ + +#ifdef _MSC_VER +#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) +#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) +#else +#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) +#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) +#endif + +#elif defined(__POWER9_VECTOR__) + +#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x) +#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x) +/* the inline asm below is about 12% faster than the lookup method */ +#define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x) +#define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x) + +static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) { + register float f; + register double d; + __asm__( + "mtfprd %0,%2\n" + "xscvhpdp %0,%0\n" + "frsp %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=f"(f): + /* in */ "r"(h)); + return f; +} + +static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) { + register double d; + register wsp_ggml_fp16_t r; + __asm__( /* xscvdphp can work on double or single precision */ + "xscvdphp %0,%2\n" + "mffprd %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=r"(r): + /* in */ "f"(f)); + return r; +} + +#else + +// FP16 <-> FP32 +// ref: https://github.com/Maratyszcza/FP16 + +static inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32; + fp32.as_bits = w; + return fp32.as_value; +} + +static inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32; + fp32.as_value = f; + return fp32.as_bits; +} + +static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) { + const uint32_t w = (uint32_t) h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + const uint32_t exp_offset = UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float 
normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + +#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x) +#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x) + +#endif // __F16C__ + +#endif // __ARM_NEON + +// +// global data +// + +// precomputed gelu table for f16 (128 KB) +static wsp_ggml_fp16_t table_gelu_f16[1 << 16]; + +// precomputed quick gelu table for f16 (128 KB) +static wsp_ggml_fp16_t table_gelu_quick_f16[1 << 16]; + +// precomputed silu table for f16 (128 KB) +static wsp_ggml_fp16_t table_silu_f16[1 << 16]; + +// precomputed exp table for f16 (128 KB) +static wsp_ggml_fp16_t table_exp_f16[1 << 16]; + +// precomputed f32 table for f16 (256 KB) +static float table_f32_f16[1 << 16]; + +#if defined(__ARM_NEON) || defined(__wasm_simd128__) +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes: +static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 +static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 +#endif + +// On ARM NEON, it's quicker to directly convert x -> x instead of calling into wsp_ggml_lookup_fp16_to_fp32, +// so we define WSP_GGML_FP16_TO_FP32 and WSP_GGML_FP32_TO_FP16 elsewhere for NEON. +// This is also true for POWER9. 
+#if !defined(WSP_GGML_FP16_TO_FP32) || !defined(WSP_GGML_FP32_TO_FP16) + +inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) { + uint16_t s; + memcpy(&s, &f, sizeof(uint16_t)); + return table_f32_f16[s]; +} + +#define WSP_GGML_FP16_TO_FP32(x) wsp_ggml_lookup_fp16_to_fp32(x) +#define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x) + +#endif + +// note: do not use these inside ggml.c +// these are meant to be used via the ggml.h API +float wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t x) { + return (float) WSP_GGML_FP16_TO_FP32(x); +} + +wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float x) { + return WSP_GGML_FP32_TO_FP16(x); +} + +void wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t * x, float * y, size_t n) { + for (size_t i = 0; i < n; i++) { + y[i] = WSP_GGML_FP16_TO_FP32(x[i]); + } +} + +void wsp_ggml_fp32_to_fp16_row(const float * x, wsp_ggml_fp16_t * y, size_t n) { + size_t i = 0; +#if defined(__F16C__) + for (; i + 7 < n; i += 8) { + __m256 x_vec = _mm256_loadu_ps(x + i); + __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128((__m128i *)(y + i), y_vec); + } + for(; i + 3 < n; i += 4) { + __m128 x_vec = _mm_loadu_ps(x + i); + __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storel_epi64((__m128i *)(y + i), y_vec); + } +#endif + for (; i < n; i++) { + y[i] = WSP_GGML_FP32_TO_FP16(x[i]); + } +} + +// +// timing +// + +#if defined(_MSC_VER) || defined(__MINGW32__) +static int64_t timer_freq, timer_start; +void wsp_ggml_time_init(void) { + LARGE_INTEGER t; + QueryPerformanceFrequency(&t); + timer_freq = t.QuadPart; + + // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq + // and the uptime is high enough. + // We subtract the program start time to reduce the likelihood of that happening. 
+ QueryPerformanceCounter(&t); + timer_start = t.QuadPart; +} +int64_t wsp_ggml_time_ms(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return ((t.QuadPart-timer_start) * 1000) / timer_freq; +} +int64_t wsp_ggml_time_us(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return ((t.QuadPart-timer_start) * 1000000) / timer_freq; +} +#else +void wsp_ggml_time_init(void) {} +int64_t wsp_ggml_time_ms(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000; +} + +int64_t wsp_ggml_time_us(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000; +} +#endif + +int64_t wsp_ggml_cycles(void) { + return clock(); +} + +int64_t wsp_ggml_cycles_per_ms(void) { + return CLOCKS_PER_SEC/1000; +} + +#ifdef WSP_GGML_PERF +#define wsp_ggml_perf_time_ms() wsp_ggml_time_ms() +#define wsp_ggml_perf_time_us() wsp_ggml_time_us() +#define wsp_ggml_perf_cycles() wsp_ggml_cycles() +#define wsp_ggml_perf_cycles_per_ms() wsp_ggml_cycles_per_ms() +#else +#define wsp_ggml_perf_time_ms() 0 +#define wsp_ggml_perf_time_us() 0 +#define wsp_ggml_perf_cycles() 0 +#define wsp_ggml_perf_cycles_per_ms() 0 +#endif + + +// +// cache line +// + +#if defined(__cpp_lib_hardware_interference_size) +#define CACHE_LINE_SIZE hardware_destructive_interference_size +#else +#if defined(__POWER9_VECTOR__) +#define CACHE_LINE_SIZE 128 +#else +#define CACHE_LINE_SIZE 64 +#endif +#endif + +static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); + +// +// quantization +// + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = _mm_sign_epi8(x, x); + // Sign the values of the y vectors + const __m128i sy = _mm_sign_epi8(y, x); + // Perform multiplication and create 16-bit values + const __m128i dot = _mm_maddubs_epi16(ax, sy); + const __m128i ones = _mm_set1_epi16(1); + return _mm_madd_epi16(ones, dot); +} + +#if __AVX__ || __AVX2__ || __AVX512F__ +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + const __m128i hi64 = _mm_unpackhi_epi64(a, a); + const __m128i sum64 = _mm_add_epi32(hi64, a); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +#if defined(__AVX2__) || defined(__AVX512F__) +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, 
x, sizeof(uint32_t)); + const __m256i shuf_mask = _mm256_set_epi64x( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); + const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytes = _mm256_or_si256(bytes, bit_mask); + return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8( 0xF ); + return _mm256_and_si256(lowMask, bytes); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + const __m256i summed_pairs = _mm256_madd_epi16(ones, x); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { +#if __AVXVNNI__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_float(dot); +#endif +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_float(ax, sy); +#endif +} + +static inline __m128i packNibbles( __m256i bytes ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh +#if __AVX512F__ + const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 + bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh + return _mm256_cvtepi16_epi8(bytes); // abcd_efgh +#else + const __m256i lowByte = _mm256_set1_epi16( 0xFF ); + __m256i high = _mm256_andnot_si256( lowByte, bytes ); + __m256i low = _mm256_and_si256( lowByte, bytes ); + high = _mm256_srli_epi16( high, 4 ); + bytes = _mm256_or_si256( low, high ); + + // Compress uint16_t lanes into bytes + __m128i r0 = _mm256_castsi256_si128( bytes ); + __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); + return _mm_packus_epi16( r0, r1 ); +#endif +} +#elif defined(__AVX__) +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); + __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); + __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); + const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytesl = _mm_or_si128(bytesl, bit_mask); + bytesh = _mm_or_si128(bytesh, bit_mask); + bytesl = _mm_cmpeq_epi8(bytesl, 
_mm_set1_epi64x(-1)); + bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); + return MM256_SET_M128I(bytesh, bytesl); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + // Load 16 bytes from memory + __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); + __m128i tmph = _mm_srli_epi16(tmpl, 4); + const __m128i lowMask = _mm_set1_epi8(0xF); + tmpl = _mm_and_si128(lowMask, tmpl); + tmph = _mm_and_si128(lowMask, tmph); + return MM256_SET_M128I(tmph, tmpl); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { + const __m128i ones = _mm_set1_epi16(1); + const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); + const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); + const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + const __m128i axl = _mm256_castsi256_si128(ax); + const __m128i axh = _mm256_extractf128_si256(ax, 1); + const __m128i syl = _mm256_castsi256_si128(sy); + const __m128i syh = _mm256_extractf128_si256(sy, 1); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + const __m128i xl = _mm256_castsi256_si128(x); + const __m128i xh = _mm256_extractf128_si256(x, 1); + const __m128i yl = _mm256_castsi256_si128(y); + const __m128i yh = _mm256_extractf128_si256(y, 1); + // Get absolute values of x vectors + const __m128i axl = _mm_sign_epi8(xl, xl); + const __m128i axh = _mm_sign_epi8(xh, xh); + // Sign the values of the y vectors + const __m128i syl = _mm_sign_epi8(yl, xl); + const __m128i syh = _mm_sign_epi8(yh, xh); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m128i lowByte = _mm_set1_epi16( 0xFF ); + __m128i high = _mm_andnot_si128( lowByte, bytes1 ); + __m128i low = _mm_and_si128( lowByte, bytes1 ); + high = _mm_srli_epi16( high, 4 ); + bytes1 = _mm_or_si128( low, high ); + high = _mm_andnot_si128( lowByte, bytes2 ); + low = _mm_and_si128( lowByte, bytes2 ); + high = _mm_srli_epi16( high, 4 ); + bytes2 = _mm_or_si128( low, high ); + + return _mm_packus_epi16( bytes1, bytes2); +} +#endif +#elif defined(__SSSE3__) +// horizontally add 4x4 floats +static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { + __m128 res_0 =_mm_hadd_ps(a, b); + __m128 res_1 =_mm_hadd_ps(c, d); + __m128 res =_mm_hadd_ps(res_0, res_1); + res =_mm_hadd_ps(res, res); + res =_mm_hadd_ps(res, res); + + return _mm_cvtss_f32(res); +} +#endif // __AVX__ || __AVX2__ || __AVX512F__ +#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) + +#if defined(__ARM_NEON) + +#if !defined(__aarch64__) + +inline static uint16_t vaddvq_u8(uint8x16_t v) { 
+ return + (uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) + + (uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) + + (uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) + + (uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) + + (uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) + + (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) + + (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) + + (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15); +} + +inline static int16_t vaddvq_s8(int8x16_t v) { + return + (int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) + + (int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) + + (int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) + + (int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) + + (int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) + + (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) + + (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) + + (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15); +} + +inline static int32_t vaddvq_s16(int16x8_t v) { + return + (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + + (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + + (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + + (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); +} + +inline static uint32_t vaddvq_u16(uint16x8_t v) { + return + (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) + + (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) + + (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) + + (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7); +} + +inline static int32_t vaddvq_s32(int32x4_t v) { + return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); +} + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +inline static float vminvq_f32(float32x4_t v) { + return + MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), + MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); +} + +inline static float vmaxvq_f32(float32x4_t v) { + return + MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), + MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); +} + +inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { + int32x4_t res; + + res[0] = roundf(vgetq_lane_f32(v, 0)); + res[1] = roundf(vgetq_lane_f32(v, 1)); + res[2] = roundf(vgetq_lane_f32(v, 2)); + res[3] = roundf(vgetq_lane_f32(v, 3)); + + return res; +} + +#endif +#endif + +#define QK4_0 32 +typedef struct { + wsp_ggml_fp16_t d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(wsp_ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +typedef struct { + wsp_ggml_fp16_t d; // delta + wsp_ggml_fp16_t m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == 2 * sizeof(wsp_ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +typedef struct { + wsp_ggml_fp16_t d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(wsp_ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +typedef struct { + 
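    // q5_1: each weight is split into 4 low bits (packed two per byte in qs)
    // plus a 5th high bit collected in qh; the dequantized value is q*d + m
    // with q in [0, 31], as in dequantize_row_q5_1 further below.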
wsp_ggml_fp16_t d; // delta + wsp_ggml_fp16_t m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(wsp_ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +typedef struct { + wsp_ggml_fp16_t d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(wsp_ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +typedef struct { + float d; // delta + float s; // d * sum(qs[i]) + int8_t qs[QK8_1]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); + +// reference implementation for deterministic creation of model files +static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { + static const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + 0 + j]*id; + const float x1 = x[i*qk + qk/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + y[i].qs[j] = xi0; + y[i].qs[j] |= xi1 << 4; + } + } +} + +static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { + quantize_row_q4_0_reference(x, y, k); +} + +static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) { + const int qk = QK4_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + y[i].m = WSP_GGML_FP32_TO_FP16(min); + + for (int j = 0; j < qk/2; ++j) { + const float x0 = (x[i*qk + 0 + j] - min)*id; + const float x1 = (x[i*qk + qk/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + y[i].qs[j] = xi0; + y[i].qs[j] |= xi1 << 4; + } + } +} + +static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { + quantize_row_q4_1_reference(x, y, k); +} + +static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { + static const int qk = QK5_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 
1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + + uint32_t qh = 0; + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + 0 + j]*id; + const float x1 = x[i*qk + qk/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10) >> 4) << (j + 0); + qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); + } + + memcpy(&y[i].qh, &qh, sizeof(qh)); + } +} + +static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { + quantize_row_q5_0_reference(x, y, k); +} + +static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { + const int qk = QK5_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 5) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + y[i].m = WSP_GGML_FP32_TO_FP16(min); + + uint32_t qh = 0; + + for (int j = 0; j < qk/2; ++j) { + const float x0 = (x[i*qk + 0 + j] - min)*id; + const float x1 = (x[i*qk + qk/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10) >> 4) << (j + 0); + qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); + } + + memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); + } +} + +static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { + quantize_row_q5_1_reference(x, y, k); +} + +// reference implementation for deterministic creation of model files +static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = x[i*QK8_0 + j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = x[i*QK8_0 + j]*id; + + y[i].qs[j] = roundf(x0); + } + } +} + +static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + } + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = WSP_GGML_FP32_TO_FP16(d); + const float id = ( maxScalar != 0.0f ) ? 
127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#else + // scalar + quantize_row_q8_0_reference(x, y, k); +#endif +} + +// reference implementation for deterministic creation of model files +static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { + assert(QK8_1 == 32); + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_1; j++) { + const float v = x[i*QK8_1 + j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + y[i].d = d; + + int sum = 0; + + for (int j = 0; j < QK8_1/2; ++j) { + const float v0 = x[i*QK8_1 + j]*id; + const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id; + + y[i].qs[ j] = roundf(v0); + y[i].qs[QK8_1/2 + j] = roundf(v1); + + sum += y[i].qs[ j]; + sum += y[i].qs[QK8_1/2 + j]; + } + + y[i].s = sum*d; + } +} + +static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + int32x4_t accv = vdupq_n_s32(0); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + + accv = vaddq_s32(accv, vi); + } + + y[i].s = d * vaddvq_s32(accv); + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + y[i].d = d; + + v128_t accv = wasm_i32x4_splat(0); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + + accv = wasm_i32x4_add(accv, vi); + } + + y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) + + wasm_i32x4_extract_lane(accv, 1) + + wasm_i32x4_extract_lane(accv, 2) + + wasm_i32x4_extract_lane(accv, 3)); + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = d; + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Compute the sum of the quants and set y[i].s + y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = 
_mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Compute the sum of the quants and set y[i].s + const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); + const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); + y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1)); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#else + // scalar + quantize_row_q8_1_reference(x, y, k); +#endif +} + +static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) { + static const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F) - 8; + const int x1 = (x[i].qs[j] >> 4) - 8; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + +static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) { + static const int qk = QK4_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + const float m = WSP_GGML_FP16_TO_FP32(x[i].m); + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F); + const int x1 = (x[i].qs[j] >> 4); + + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; + } + } +} + +static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) { + static const int qk = QK5_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + +static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) { + static const int qk = QK5_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + const float m = WSP_GGML_FP16_TO_FP32(x[i].m); + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int x0 = (x[i].qs[j] & 0x0F) | xh_0; + const int x1 = (x[i].qs[j] >> 4) | xh_1; + + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; + } + } +} + +static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) { + static const int qk = QK8_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + const block_q8_0 * restrict x = vx; + + for (int i = 0; i < nb; i++) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + + for (int j = 0; j < qk; 
++j) { + y[i*qk + j] = x[i].qs[j]*d; + } + } +} + +static void wsp_ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void wsp_ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void wsp_ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void wsp_ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void wsp_ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); + +static const quantize_fns_t quantize_fns[WSP_GGML_TYPE_COUNT] = { + [WSP_GGML_TYPE_Q4_0] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q4_0, + .quantize_row_q = quantize_row_q4_0, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = wsp_ggml_vec_dot_q4_0_q8_0, + .vec_dot_type = WSP_GGML_TYPE_Q8_0, + }, + [WSP_GGML_TYPE_Q4_1] = { + .dequantize_row_q = (dequantize_row_q_t)dequantize_row_q4_1, + .quantize_row_q = quantize_row_q4_1, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference, + .quantize_row_q_dot = quantize_row_q8_1, + .vec_dot_q = wsp_ggml_vec_dot_q4_1_q8_1, + .vec_dot_type = WSP_GGML_TYPE_Q8_1, + }, + [WSP_GGML_TYPE_Q5_0] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q5_0, + .quantize_row_q = quantize_row_q5_0, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = wsp_ggml_vec_dot_q5_0_q8_0, + .vec_dot_type = WSP_GGML_TYPE_Q8_0, + }, + [WSP_GGML_TYPE_Q5_1] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q5_1, + .quantize_row_q = quantize_row_q5_1, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference, + .quantize_row_q_dot = quantize_row_q8_1, + .vec_dot_q = wsp_ggml_vec_dot_q5_1_q8_1, + .vec_dot_type = WSP_GGML_TYPE_Q8_1, + }, + [WSP_GGML_TYPE_Q8_0] = { + .dequantize_row_q = dequantize_row_q8_0, + .quantize_row_q = quantize_row_q8_0, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = wsp_ggml_vec_dot_q8_0_q8_0, + .vec_dot_type = WSP_GGML_TYPE_Q8_0, + }, + [WSP_GGML_TYPE_Q8_1] = { + .dequantize_row_q = NULL, // TODO + .quantize_row_q = quantize_row_q8_1, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference, + .quantize_row_q_dot = quantize_row_q8_1, + .vec_dot_q = NULL, // TODO + .vec_dot_type = WSP_GGML_TYPE_Q8_1, + }, +#ifdef WSP_GGML_USE_K_QUANTS + [WSP_GGML_TYPE_Q2_K] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q2_K, + .quantize_row_q = quantize_row_q2_K, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_K_reference, + .quantize_row_q_dot = quantize_row_q8_K, + .vec_dot_q = wsp_ggml_vec_dot_q2_K_q8_K, + .vec_dot_type = WSP_GGML_TYPE_Q8_K, + }, + [WSP_GGML_TYPE_Q3_K] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q3_K, + .quantize_row_q = quantize_row_q3_K, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_K_reference, + .quantize_row_q_dot = quantize_row_q8_K, + .vec_dot_q = wsp_ggml_vec_dot_q3_K_q8_K, + .vec_dot_type = WSP_GGML_TYPE_Q8_K, + }, + [WSP_GGML_TYPE_Q4_K] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q4_K, + .quantize_row_q = quantize_row_q4_K, + 
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_K_reference, + .quantize_row_q_dot = quantize_row_q8_K, + .vec_dot_q = wsp_ggml_vec_dot_q4_K_q8_K, + .vec_dot_type = WSP_GGML_TYPE_Q8_K, + }, + [WSP_GGML_TYPE_Q5_K] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q5_K, + .quantize_row_q = quantize_row_q5_K, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_K_reference, + .quantize_row_q_dot = quantize_row_q8_K, + .vec_dot_q = wsp_ggml_vec_dot_q5_K_q8_K, + .vec_dot_type = WSP_GGML_TYPE_Q8_K, + }, + [WSP_GGML_TYPE_Q6_K] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q6_K, + .quantize_row_q = quantize_row_q6_K, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q6_K_reference, + .quantize_row_q_dot = quantize_row_q8_K, + .vec_dot_q = wsp_ggml_vec_dot_q6_K_q8_K, + .vec_dot_type = WSP_GGML_TYPE_Q8_K, + }, +#endif +}; + +// For internal test use +quantize_fns_t wsp_ggml_internal_get_quantize_fn(size_t i) { + WSP_GGML_ASSERT(i < WSP_GGML_TYPE_COUNT); + return quantize_fns[i]; +} + + +// +// simd mappings +// + +// we define a common set of C macros which map to specific intrinsics based on the current architecture +// we then implement the fundamental computation operations below using only these macros +// adding support for new architectures requires to define the corresponding SIMD macros +// +// WSP_GGML_F32_STEP / WSP_GGML_F16_STEP +// number of elements to process in a single step +// +// WSP_GGML_F32_EPR / WSP_GGML_F16_EPR +// number of elements to fit in a single register +// + +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) + +#define WSP_GGML_SIMD + +// F32 NEON + +#define WSP_GGML_F32_STEP 16 +#define WSP_GGML_F32_EPR 4 + +#define WSP_GGML_F32x4 float32x4_t +#define WSP_GGML_F32x4_ZERO vdupq_n_f32(0.0f) +#define WSP_GGML_F32x4_SET1(x) vdupq_n_f32(x) +#define WSP_GGML_F32x4_LOAD vld1q_f32 +#define WSP_GGML_F32x4_STORE vst1q_f32 +#define WSP_GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) +#define WSP_GGML_F32x4_ADD vaddq_f32 +#define WSP_GGML_F32x4_MUL vmulq_f32 +#define WSP_GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) +#define WSP_GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = WSP_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f32(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f32(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f32(x[i], x[offset+i]); \ + } \ + res = WSP_GGML_F32x4_REDUCE_ONE(x[0]); \ +} + +#define WSP_GGML_F32_VEC WSP_GGML_F32x4 +#define WSP_GGML_F32_VEC_ZERO WSP_GGML_F32x4_ZERO +#define WSP_GGML_F32_VEC_SET1 WSP_GGML_F32x4_SET1 +#define WSP_GGML_F32_VEC_LOAD WSP_GGML_F32x4_LOAD +#define WSP_GGML_F32_VEC_STORE WSP_GGML_F32x4_STORE +#define WSP_GGML_F32_VEC_FMA WSP_GGML_F32x4_FMA +#define WSP_GGML_F32_VEC_ADD WSP_GGML_F32x4_ADD +#define WSP_GGML_F32_VEC_MUL WSP_GGML_F32x4_MUL +#define WSP_GGML_F32_VEC_REDUCE WSP_GGML_F32x4_REDUCE + +// F16 NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define WSP_GGML_F16_STEP 32 + #define WSP_GGML_F16_EPR 8 + + #define WSP_GGML_F16x8 float16x8_t + #define WSP_GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define WSP_GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define WSP_GGML_F16x8_LOAD vld1q_f16 + #define WSP_GGML_F16x8_STORE vst1q_f16 + #define WSP_GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define WSP_GGML_F16x8_ADD vaddq_f16 + #define WSP_GGML_F16x8_MUL vmulq_f16 + #define WSP_GGML_F16x8_REDUCE(res, x) \ + { \ + int offset = 
WSP_GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f16(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f16(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f16(x[i], x[offset+i]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ + res = (wsp_ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ + } + + #define WSP_GGML_F16_VEC WSP_GGML_F16x8 + #define WSP_GGML_F16_VEC_ZERO WSP_GGML_F16x8_ZERO + #define WSP_GGML_F16_VEC_SET1 WSP_GGML_F16x8_SET1 + #define WSP_GGML_F16_VEC_LOAD(p, i) WSP_GGML_F16x8_LOAD(p) + #define WSP_GGML_F16_VEC_STORE(p, r, i) WSP_GGML_F16x8_STORE(p, r[i]) + #define WSP_GGML_F16_VEC_FMA WSP_GGML_F16x8_FMA + #define WSP_GGML_F16_VEC_ADD WSP_GGML_F16x8_ADD + #define WSP_GGML_F16_VEC_MUL WSP_GGML_F16x8_MUL + #define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define WSP_GGML_F16_STEP 16 + #define WSP_GGML_F16_EPR 4 + + #define WSP_GGML_F32Cx4 float32x4_t + #define WSP_GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define WSP_GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define WSP_GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x)) + #define WSP_GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define WSP_GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define WSP_GGML_F32Cx4_ADD vaddq_f32 + #define WSP_GGML_F32Cx4_MUL vmulq_f32 + #define WSP_GGML_F32Cx4_REDUCE WSP_GGML_F32x4_REDUCE + + #define WSP_GGML_F16_VEC WSP_GGML_F32Cx4 + #define WSP_GGML_F16_VEC_ZERO WSP_GGML_F32Cx4_ZERO + #define WSP_GGML_F16_VEC_SET1 WSP_GGML_F32Cx4_SET1 + #define WSP_GGML_F16_VEC_LOAD(p, i) WSP_GGML_F32Cx4_LOAD(p) + #define WSP_GGML_F16_VEC_STORE(p, r, i) WSP_GGML_F32Cx4_STORE(p, r[i]) + #define WSP_GGML_F16_VEC_FMA WSP_GGML_F32Cx4_FMA + #define WSP_GGML_F16_VEC_ADD WSP_GGML_F32Cx4_ADD + #define WSP_GGML_F16_VEC_MUL WSP_GGML_F32Cx4_MUL + #define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32Cx4_REDUCE +#endif + +#elif defined(__AVX__) + +#define WSP_GGML_SIMD + +// F32 AVX + +#define WSP_GGML_F32_STEP 32 +#define WSP_GGML_F32_EPR 8 + +#define WSP_GGML_F32x8 __m256 +#define WSP_GGML_F32x8_ZERO _mm256_setzero_ps() +#define WSP_GGML_F32x8_SET1(x) _mm256_set1_ps(x) +#define WSP_GGML_F32x8_LOAD _mm256_loadu_ps +#define WSP_GGML_F32x8_STORE _mm256_storeu_ps +#if defined(__FMA__) + #define WSP_GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) +#else + #define WSP_GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) +#endif +#define WSP_GGML_F32x8_ADD _mm256_add_ps +#define WSP_GGML_F32x8_MUL _mm256_mul_ps +#define WSP_GGML_F32x8_REDUCE(res, x) \ +{ \ + int offset = WSP_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ + _mm256_extractf128_ps(x[0], 1)); \ + const __m128 t1 = _mm_hadd_ps(t0, t0); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ +} +// TODO: is this optimal ? 
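// A minimal scalar sketch of the pattern these macros implement (the helper
// below is illustrative only; its name is made up and it is not used anywhere):
// with WSP_GGML_F32_STEP = 32 and WSP_GGML_F32_EPR = 8 there are
// WSP_GGML_F32_ARR = 4 partial accumulators per step, and REDUCE folds them
// pairwise (offset 2, then 1; the third halving is a no-op for 4 accumulators)
// before the final horizontal add.
static inline float wsp_ggml_f32x8_dot_sketch(const float * x, const float * y, int n) {
    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };   // one accumulator per "register"
    for (int i = 0; i + 32 <= n; i += 32) {      // WSP_GGML_F32_STEP elements per step
        for (int j = 0; j < 4; j++) {            // WSP_GGML_F32_ARR accumulators ...
            for (int k = 0; k < 8; k++) {        // ... of WSP_GGML_F32_EPR lanes each
                acc[j] += x[i + j*8 + k]*y[i + j*8 + k];
            }
        }
    }
    acc[0] += acc[2]; acc[1] += acc[3];          // REDUCE: offset = 2
    acc[0] += acc[1];                            //         offset = 1
    return acc[0];                               // horizontal add of the last register
}
// (wsp_ggml_vec_dot_f32 further below does the same with the SIMD macros, plus a
// scalar tail loop for n not a multiple of WSP_GGML_F32_STEP.)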
+ +#define WSP_GGML_F32_VEC WSP_GGML_F32x8 +#define WSP_GGML_F32_VEC_ZERO WSP_GGML_F32x8_ZERO +#define WSP_GGML_F32_VEC_SET1 WSP_GGML_F32x8_SET1 +#define WSP_GGML_F32_VEC_LOAD WSP_GGML_F32x8_LOAD +#define WSP_GGML_F32_VEC_STORE WSP_GGML_F32x8_STORE +#define WSP_GGML_F32_VEC_FMA WSP_GGML_F32x8_FMA +#define WSP_GGML_F32_VEC_ADD WSP_GGML_F32x8_ADD +#define WSP_GGML_F32_VEC_MUL WSP_GGML_F32x8_MUL +#define WSP_GGML_F32_VEC_REDUCE WSP_GGML_F32x8_REDUCE + +// F16 AVX + +#define WSP_GGML_F16_STEP 32 +#define WSP_GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead + +#define WSP_GGML_F32Cx8 __m256 +#define WSP_GGML_F32Cx8_ZERO _mm256_setzero_ps() +#define WSP_GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) + +#if defined(__F16C__) +// the _mm256_cvt intrinsics require F16C +#define WSP_GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define WSP_GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) +#else +static inline __m256 __avx_f32cx8_load(wsp_ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) { + tmp[i] = WSP_GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline void __avx_f32cx8_store(wsp_ggml_fp16_t *x, __m256 y) { + float arr[8]; + + _mm256_storeu_ps(arr, y); + + for (int i = 0; i < 8; i++) + x[i] = WSP_GGML_FP32_TO_FP16(arr[i]); +} +#define WSP_GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) +#define WSP_GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y) +#endif + +#define WSP_GGML_F32Cx8_FMA WSP_GGML_F32x8_FMA +#define WSP_GGML_F32Cx8_ADD _mm256_add_ps +#define WSP_GGML_F32Cx8_MUL _mm256_mul_ps +#define WSP_GGML_F32Cx8_REDUCE WSP_GGML_F32x8_REDUCE + +#define WSP_GGML_F16_VEC WSP_GGML_F32Cx8 +#define WSP_GGML_F16_VEC_ZERO WSP_GGML_F32Cx8_ZERO +#define WSP_GGML_F16_VEC_SET1 WSP_GGML_F32Cx8_SET1 +#define WSP_GGML_F16_VEC_LOAD(p, i) WSP_GGML_F32Cx8_LOAD(p) +#define WSP_GGML_F16_VEC_STORE(p, r, i) WSP_GGML_F32Cx8_STORE(p, r[i]) +#define WSP_GGML_F16_VEC_FMA WSP_GGML_F32Cx8_FMA +#define WSP_GGML_F16_VEC_ADD WSP_GGML_F32Cx8_ADD +#define WSP_GGML_F16_VEC_MUL WSP_GGML_F32Cx8_MUL +#define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32Cx8_REDUCE + +#elif defined(__POWER9_VECTOR__) + +#define WSP_GGML_SIMD + +// F32 POWER9 + +#define WSP_GGML_F32_STEP 32 +#define WSP_GGML_F32_EPR 4 + +#define WSP_GGML_F32x4 vector float +#define WSP_GGML_F32x4_ZERO 0.0f +#define WSP_GGML_F32x4_SET1 vec_splats +#define WSP_GGML_F32x4_LOAD(p) vec_xl(0, p) +#define WSP_GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) +#define WSP_GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) +#define WSP_GGML_F32x4_ADD vec_add +#define WSP_GGML_F32x4_MUL vec_mul +#define WSP_GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = WSP_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + res = vec_extract(x[0], 0) + \ + vec_extract(x[0], 1) + \ + vec_extract(x[0], 2) + \ + vec_extract(x[0], 3); \ +} + +#define WSP_GGML_F32_VEC WSP_GGML_F32x4 +#define WSP_GGML_F32_VEC_ZERO WSP_GGML_F32x4_ZERO +#define WSP_GGML_F32_VEC_SET1 WSP_GGML_F32x4_SET1 +#define WSP_GGML_F32_VEC_LOAD WSP_GGML_F32x4_LOAD +#define WSP_GGML_F32_VEC_STORE WSP_GGML_F32x4_STORE +#define WSP_GGML_F32_VEC_FMA WSP_GGML_F32x4_FMA +#define WSP_GGML_F32_VEC_ADD WSP_GGML_F32x4_ADD +#define WSP_GGML_F32_VEC_MUL WSP_GGML_F32x4_MUL +#define 
WSP_GGML_F32_VEC_REDUCE WSP_GGML_F32x4_REDUCE + +// F16 POWER9 +#define WSP_GGML_F16_STEP WSP_GGML_F32_STEP +#define WSP_GGML_F16_EPR WSP_GGML_F32_EPR +#define WSP_GGML_F16_VEC WSP_GGML_F32x4 +#define WSP_GGML_F16_VEC_ZERO WSP_GGML_F32x4_ZERO +#define WSP_GGML_F16_VEC_SET1 WSP_GGML_F32x4_SET1 +#define WSP_GGML_F16_VEC_FMA WSP_GGML_F32x4_FMA +#define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32x4_REDUCE +// Use vec_xl, not vec_ld, in case the load address is not aligned. +#define WSP_GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ + vec_extract_fp32_from_shorth(vec_xl(0, p - WSP_GGML_F16_EPR)) : \ + vec_extract_fp32_from_shortl(vec_xl(0, p)) +#define WSP_GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] +#define WSP_GGML_F16_VEC_STORE(p, r, i) \ + if (i & 0x1) \ + vec_xst(vec_pack_to_short_fp32(r[i - WSP_GGML_ENDIAN_BYTE(1)], \ + r[i - WSP_GGML_ENDIAN_BYTE(0)]), \ + 0, p - WSP_GGML_F16_EPR) + +#elif defined(__wasm_simd128__) + +#define WSP_GGML_SIMD + +// F32 WASM + +#define WSP_GGML_F32_STEP 16 +#define WSP_GGML_F32_EPR 4 + +#define WSP_GGML_F32x4 v128_t +#define WSP_GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) +#define WSP_GGML_F32x4_SET1(x) wasm_f32x4_splat(x) +#define WSP_GGML_F32x4_LOAD wasm_v128_load +#define WSP_GGML_F32x4_STORE wasm_v128_store +#define WSP_GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) +#define WSP_GGML_F32x4_ADD wasm_f32x4_add +#define WSP_GGML_F32x4_MUL wasm_f32x4_mul +#define WSP_GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = WSP_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define WSP_GGML_F32_VEC WSP_GGML_F32x4 +#define WSP_GGML_F32_VEC_ZERO WSP_GGML_F32x4_ZERO +#define WSP_GGML_F32_VEC_SET1 WSP_GGML_F32x4_SET1 +#define WSP_GGML_F32_VEC_LOAD WSP_GGML_F32x4_LOAD +#define WSP_GGML_F32_VEC_STORE WSP_GGML_F32x4_STORE +#define WSP_GGML_F32_VEC_FMA WSP_GGML_F32x4_FMA +#define WSP_GGML_F32_VEC_ADD WSP_GGML_F32x4_ADD +#define WSP_GGML_F32_VEC_MUL WSP_GGML_F32x4_MUL +#define WSP_GGML_F32_VEC_REDUCE WSP_GGML_F32x4_REDUCE + +// F16 WASM + +#define WSP_GGML_F16_STEP 16 +#define WSP_GGML_F16_EPR 4 + +inline static v128_t __wasm_f16x4_load(const wsp_ggml_fp16_t * p) { + float tmp[4]; + + tmp[0] = WSP_GGML_FP16_TO_FP32(p[0]); + tmp[1] = WSP_GGML_FP16_TO_FP32(p[1]); + tmp[2] = WSP_GGML_FP16_TO_FP32(p[2]); + tmp[3] = WSP_GGML_FP16_TO_FP32(p[3]); + + return wasm_v128_load(tmp); +} + +inline static void __wasm_f16x4_store(wsp_ggml_fp16_t * p, v128_t x) { + float tmp[4]; + + wasm_v128_store(tmp, x); + + p[0] = WSP_GGML_FP32_TO_FP16(tmp[0]); + p[1] = WSP_GGML_FP32_TO_FP16(tmp[1]); + p[2] = WSP_GGML_FP32_TO_FP16(tmp[2]); + p[3] = WSP_GGML_FP32_TO_FP16(tmp[3]); +} + +#define WSP_GGML_F16x4 v128_t +#define WSP_GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) +#define WSP_GGML_F16x4_SET1(x) wasm_f32x4_splat(x) +#define WSP_GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) +#define WSP_GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) +#define WSP_GGML_F16x4_FMA WSP_GGML_F32x4_FMA +#define WSP_GGML_F16x4_ADD wasm_f32x4_add +#define WSP_GGML_F16x4_MUL wasm_f32x4_mul +#define WSP_GGML_F16x4_REDUCE(res, x) \ +{ \ + int offset = WSP_GGML_F16_ARR >> 1; \ + for (int i = 
0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define WSP_GGML_F16_VEC WSP_GGML_F16x4 +#define WSP_GGML_F16_VEC_ZERO WSP_GGML_F16x4_ZERO +#define WSP_GGML_F16_VEC_SET1 WSP_GGML_F16x4_SET1 +#define WSP_GGML_F16_VEC_LOAD(p, i) WSP_GGML_F16x4_LOAD(p) +#define WSP_GGML_F16_VEC_STORE(p, r, i) WSP_GGML_F16x4_STORE(p, r[i]) +#define WSP_GGML_F16_VEC_FMA WSP_GGML_F16x4_FMA +#define WSP_GGML_F16_VEC_ADD WSP_GGML_F16x4_ADD +#define WSP_GGML_F16_VEC_MUL WSP_GGML_F16x4_MUL +#define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F16x4_REDUCE + +#elif defined(__SSE3__) + +#define WSP_GGML_SIMD + +// F32 SSE + +#define WSP_GGML_F32_STEP 32 +#define WSP_GGML_F32_EPR 4 + +#define WSP_GGML_F32x4 __m128 +#define WSP_GGML_F32x4_ZERO _mm_setzero_ps() +#define WSP_GGML_F32x4_SET1(x) _mm_set1_ps(x) +#define WSP_GGML_F32x4_LOAD _mm_loadu_ps +#define WSP_GGML_F32x4_STORE _mm_storeu_ps +#if defined(__FMA__) + // TODO: Does this work? + #define WSP_GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) +#else + #define WSP_GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) +#endif +#define WSP_GGML_F32x4_ADD _mm_add_ps +#define WSP_GGML_F32x4_MUL _mm_mul_ps +#define WSP_GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = WSP_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ +} +// TODO: is this optimal ? 
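// Note: on the SSE3 path WSP_GGML_F32_STEP is 32 with WSP_GGML_F32_EPR = 4, so
// WSP_GGML_F32_ARR = 8 partial sums are kept and the three halvings above fold
// them 8 -> 4 -> 2 -> 1 before the two hadd instructions collapse the last
// __m128 to a scalar.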
+ +#define WSP_GGML_F32_VEC WSP_GGML_F32x4 +#define WSP_GGML_F32_VEC_ZERO WSP_GGML_F32x4_ZERO +#define WSP_GGML_F32_VEC_SET1 WSP_GGML_F32x4_SET1 +#define WSP_GGML_F32_VEC_LOAD WSP_GGML_F32x4_LOAD +#define WSP_GGML_F32_VEC_STORE WSP_GGML_F32x4_STORE +#define WSP_GGML_F32_VEC_FMA WSP_GGML_F32x4_FMA +#define WSP_GGML_F32_VEC_ADD WSP_GGML_F32x4_ADD +#define WSP_GGML_F32_VEC_MUL WSP_GGML_F32x4_MUL +#define WSP_GGML_F32_VEC_REDUCE WSP_GGML_F32x4_REDUCE + +// F16 SSE + +#define WSP_GGML_F16_STEP 32 +#define WSP_GGML_F16_EPR 4 + +static inline __m128 __sse_f16x4_load(wsp_ggml_fp16_t *x) { + float tmp[4]; + + tmp[0] = WSP_GGML_FP16_TO_FP32(x[0]); + tmp[1] = WSP_GGML_FP16_TO_FP32(x[1]); + tmp[2] = WSP_GGML_FP16_TO_FP32(x[2]); + tmp[3] = WSP_GGML_FP16_TO_FP32(x[3]); + + return _mm_loadu_ps(tmp); +} + +static inline void __sse_f16x4_store(wsp_ggml_fp16_t *x, __m128 y) { + float arr[4]; + + _mm_storeu_ps(arr, y); + + x[0] = WSP_GGML_FP32_TO_FP16(arr[0]); + x[1] = WSP_GGML_FP32_TO_FP16(arr[1]); + x[2] = WSP_GGML_FP32_TO_FP16(arr[2]); + x[3] = WSP_GGML_FP32_TO_FP16(arr[3]); +} + +#define WSP_GGML_F32Cx4 __m128 +#define WSP_GGML_F32Cx4_ZERO _mm_setzero_ps() +#define WSP_GGML_F32Cx4_SET1(x) _mm_set1_ps(x) +#define WSP_GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x) +#define WSP_GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) +#define WSP_GGML_F32Cx4_FMA WSP_GGML_F32x4_FMA +#define WSP_GGML_F32Cx4_ADD _mm_add_ps +#define WSP_GGML_F32Cx4_MUL _mm_mul_ps +#define WSP_GGML_F32Cx4_REDUCE WSP_GGML_F32x4_REDUCE + +#define WSP_GGML_F16_VEC WSP_GGML_F32Cx4 +#define WSP_GGML_F16_VEC_ZERO WSP_GGML_F32Cx4_ZERO +#define WSP_GGML_F16_VEC_SET1 WSP_GGML_F32Cx4_SET1 +#define WSP_GGML_F16_VEC_LOAD(p, i) WSP_GGML_F32Cx4_LOAD(p) +#define WSP_GGML_F16_VEC_STORE(p, r, i) WSP_GGML_F32Cx4_STORE(p, r[i]) +#define WSP_GGML_F16_VEC_FMA WSP_GGML_F32Cx4_FMA +#define WSP_GGML_F16_VEC_ADD WSP_GGML_F32Cx4_ADD +#define WSP_GGML_F16_VEC_MUL WSP_GGML_F32Cx4_MUL +#define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32Cx4_REDUCE + +#endif + +// WSP_GGML_F32_ARR / WSP_GGML_F16_ARR +// number of registers to use per step +#ifdef WSP_GGML_SIMD +#define WSP_GGML_F32_ARR (WSP_GGML_F32_STEP/WSP_GGML_F32_EPR) +#define WSP_GGML_F16_ARR (WSP_GGML_F16_STEP/WSP_GGML_F16_EPR) +#endif + +// +// fundamental operations +// + +inline static void wsp_ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void wsp_ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void wsp_ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void wsp_ggml_vec_set_f16(const int n, wsp_ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void wsp_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void wsp_ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } +inline static void wsp_ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void wsp_ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } +inline static void wsp_ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void wsp_ggml_vec_set_f32 (const int n, 
float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void wsp_ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void wsp_ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void wsp_ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void wsp_ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } + +inline static void wsp_ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { +#ifdef WSP_GGML_SIMD + float sumf = 0.0f; + const int np = (n & ~(WSP_GGML_F32_STEP - 1)); + + WSP_GGML_F32_VEC sum[WSP_GGML_F32_ARR] = { WSP_GGML_F32_VEC_ZERO }; + + WSP_GGML_F32_VEC ax[WSP_GGML_F32_ARR]; + WSP_GGML_F32_VEC ay[WSP_GGML_F32_ARR]; + + for (int i = 0; i < np; i += WSP_GGML_F32_STEP) { + for (int j = 0; j < WSP_GGML_F32_ARR; j++) { + ax[j] = WSP_GGML_F32_VEC_LOAD(x + i + j*WSP_GGML_F32_EPR); + ay[j] = WSP_GGML_F32_VEC_LOAD(y + i + j*WSP_GGML_F32_EPR); + + sum[j] = WSP_GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + WSP_GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; + } +#else + // scalar + wsp_ggml_float sumf = 0.0; + for (int i = 0; i < n; ++i) { + sumf += (wsp_ggml_float)(x[i]*y[i]); + } +#endif + + *s = sumf; +} + +inline static void wsp_ggml_vec_dot_f16(const int n, float * restrict s, wsp_ggml_fp16_t * restrict x, wsp_ggml_fp16_t * restrict y) { + wsp_ggml_float sumf = 0.0; + +#if defined(WSP_GGML_SIMD) + const int np = (n & ~(WSP_GGML_F16_STEP - 1)); + + WSP_GGML_F16_VEC sum[WSP_GGML_F16_ARR] = { WSP_GGML_F16_VEC_ZERO }; + + WSP_GGML_F16_VEC ax[WSP_GGML_F16_ARR]; + WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR]; + + for (int i = 0; i < np; i += WSP_GGML_F16_STEP) { + for (int j = 0; j < WSP_GGML_F16_ARR; j++) { + ax[j] = WSP_GGML_F16_VEC_LOAD(x + i + j*WSP_GGML_F16_EPR, j); + ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j); + + sum[j] = WSP_GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + WSP_GGML_F16_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += (wsp_ggml_float)(WSP_GGML_FP16_TO_FP32(x[i])*WSP_GGML_FP16_TO_FP32(y[i])); + } +#else + for (int i = 0; i < n; ++i) { + sumf += (wsp_ggml_float)(WSP_GGML_FP16_TO_FP32(x[i])*WSP_GGML_FP16_TO_FP32(y[i])); + } +#endif + + *s = sumf; +} + +static void wsp_ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nb % 2 == 0); + + const block_q4_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i += 2) { + const block_q4_0 * restrict x0 = &x[i + 0]; + const block_q4_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + 
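            // (vandq_u8 with m4b keeps the low nibble of each byte; the
            // vshrq_n_u8 below extracts the high nibble, and both halves are
            // then re-biased from [0, 15] to [-8, 7] by subtracting s8b)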
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int32x4_t + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps( WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d) ); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. 
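        // Each q4_0 weight dequantizes to d*(q - 8) with q in [0, 15] and each
        // q8_0 value to d_y*q_y, so for one block:
        //     sum_j x_j*y_j = d*d_y * sum_j (q_j - 8)*q_y_j
        // which is why the int8 dot product below is taken after subtracting 8
        // and then scaled once by the combined d*d_y computed above.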
+ const __m256i off = _mm256_set1_epi8( 8 ); + bx = _mm256_sub_epi8( bx, off ); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( d, q, acc ); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps( WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d) ); + + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); + + __m128i bx = _mm_and_si128(lowMask, tmp); + __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); + bx = _mm_sub_epi8(bx, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx, by); + + bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); + by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx = _mm_sub_epi8(bx, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx, by); + + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); + + // Apply the scale, and accumulate + acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); + } + + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + // set constants + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + // Initialize accumulator with zeros + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + // First round without accumulation + { + _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( WSP_GGML_FP16_TO_FP32(x[0].d) * WSP_GGML_FP16_TO_FP32(y[0].d) ); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( WSP_GGML_FP16_TO_FP32(x[1].d) * WSP_GGML_FP16_TO_FP32(y[1].d) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + acc_0 = _mm_mul_ps( d_0_1, p0 ); + acc_1 = _mm_mul_ps( d_0_1, p1 ); + acc_2 = _mm_mul_ps( d_2_3, p2 ); + 
acc_3 = _mm_mul_ps( d_2_3, p3 ); + } + + // Main loop + for (int i = 2; i < nb; i+=2) { + _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d) ); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( WSP_GGML_FP16_TO_FP32(x[i + 1].d) * WSP_GGML_FP16_TO_FP32(y[i + 1].d) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); + __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); + __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); + __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); + + // Accumulate + acc_0 = _mm_add_ps(p0_d, acc_0); + acc_1 = _mm_add_ps(p1_d, acc_1); + acc_2 = _mm_add_ps(p2_d, acc_2); + acc_3 = _mm_add_ps(p3_d, acc_3); + } + + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0x0F) - 8; + const int v1 = (x[i].qs[j] >> 4) - 8; + + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + } + + sumf += sumi*WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d); + } + + *s = sumf; +#endif +} + +static void wsp_ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nb % 2 == 0); + + const block_q4_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + + // TODO: add WASM SIMD +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs = 0; + + for (int i = 0; i < nb; i += 2) { + const block_q4_1 * restrict x0 = &x[i + 0]; + const block_q4_1 * restrict x1 = &x[i + 1]; + const block_q8_1 * restrict y0 = &y[i + 0]; + const block_q8_1 * restrict y1 = &y[i + 1]; + + summs += WSP_GGML_FP16_TO_FP32(x0->m) * y0->s + WSP_GGML_FP16_TO_FP32(x1->m) * y1->s; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const
int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int32x4_t + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + // Main loop + for (int i = 0; i < nb; ++i) { + const float d0 = WSP_GGML_FP16_TO_FP32(x[i].d); + const float d1 = y[i].d; + + summs += WSP_GGML_FP16_TO_FP32(x[i].m) * y[i].s; + + const __m256 d0v = _mm256_set1_ps( d0 ); + const __m256 d1v = _mm256_set1_ps( d1 ); + + // Compute combined scales + const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); + + const __m256 xy = mul_sum_us8_pairs_float(bx, by); + + // Accumulate d0*d1*x*y +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d0d1, xy, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); +#endif + } + + *s = hsum_float_8(acc) + summs; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0x0F); + const int v1 = (x[i].qs[j] >> 4); + + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + } + + sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + WSP_GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#endif +} + +static void wsp_ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const 
int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nb % 2 == 0); + assert(qk == QK5_0); + + const block_q5_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + for (int i = 0; i < nb; i += 2) { + const block_q5_0 * restrict x0 = &x[i]; + const block_q5_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + // extract the 5th bit via lookup table ((!b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_1[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_1[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), 
vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); + + uint32_t qh; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (int i = 0; i < nb; ++i) { + const block_q5_0 * restrict x0 = &x[i]; + const block_q8_0 * restrict y0 = &y[i]; + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); + const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( + wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(WSP_GGML_FP16_TO_FP32(x0->d) * WSP_GGML_FP16_TO_FP32(y0->d)))); + } + + *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d)); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); + bx = _mm256_or_si256(bx, bxhi); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8((char)0xF0); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d)); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi 
= bytes_from_bits_32(x[i].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_andnot_si128(bxhil, mask); + bxhih = _mm_andnot_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx); + __m128i bxh = _mm256_extractf128_si256(bx, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx = MM256_SET_M128I(bxh, bxl); + + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); + } + + *s = hsum_float_8(acc); +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + + sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + } + + sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d)) * sumi; + } + + *s = sumf; +#endif +} + +static void wsp_ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nb % 2 == 0); + assert(qk == QK5_1); + + const block_q5_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs0 = 0.0f; + float summs1 = 0.0f; + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + for (int i = 0; i < nb; i += 2) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q5_1 * restrict x1 = &x[i + 1]; + const block_q8_1 * restrict y0 = &y[i]; + const block_q8_1 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + summs0 += WSP_GGML_FP16_TO_FP32(x0->m) * y0->s; + summs1 += WSP_GGML_FP16_TO_FP32(x1->m) * y1->s; + + // extract the 5th bit via lookup table ((b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_0[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_0[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit + const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = 
vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); + + float summs = 0.0f; + + uint32_t qh; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (int i = 0; i < nb; ++i) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q8_1 * restrict y0 = &y[i]; + + summs += WSP_GGML_FP16_TO_FP32(x0->m) * y0->s; + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit + const v128_t v0lf = wasm_v128_or(v0l, qhl); + const v128_t v0hf = wasm_v128_or(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, 
v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(WSP_GGML_FP16_TO_FP32(x0->d) * y0->d))); + } + + *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.0f; + + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d)); + + summs += WSP_GGML_FP16_TO_FP32(x[i].m) * y[i].s; + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); + bx = _mm256_or_si256(bx, bxhi); + + const __m256 dy = _mm256_set1_ps(y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_us8_pairs_float(bx, by); + + acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + } + + *s = hsum_float_8(acc) + summs; +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8(0x10); + + float summs = 0.0f; + + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d)); + + summs += WSP_GGML_FP16_TO_FP32(x[i].m) * y[i].s; + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_and_si128(bxhil, mask); + bxhih = _mm_and_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx); + __m128i bxh = _mm256_extractf128_si256(bx, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx = MM256_SET_M128I(bxh, bxl); + + const __m256 dy = _mm256_set1_ps(y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_us8_pairs_float(bx, by); + + acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); + } + + *s = hsum_float_8(acc) + summs; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; + + sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + } + + sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + WSP_GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#endif +} + +static void wsp_ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nb % 2 == 0); + + const block_q8_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i += 2) { + const block_q8_0 * restrict x0 = &x[i + 0]; + const block_q8_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const int8x16_t x0_0 = vld1q_s8(x0->qs); + const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); + const int8x16_t x1_0 = vld1q_s8(x1->qs); + const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); + + // load y + 
const int8x16_t y0_0 = vld1q_s8(y0->qs); + const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); + const int8x16_t y1_0 = vld1q_s8(y1->qs); + const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), + vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), + vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); + +#else + const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); + const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); + const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1)); + const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); + + const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0)); + const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); + const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1)); + const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); + + const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); + const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); + const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); + const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d)); + __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs); + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + // Multiply q with scale and accumulate +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d, q, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); +#endif + } + + *s = hsum_float_8(acc); +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[i].qs[j]*y[i].qs[j]; + } + + sumf += sumi*(WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d)); + } + + *s = sumf; +#endif +} + +// compute WSP_GGML_VEC_DOT_UNROLL dot products at once +// xs - x row stride in bytes +inline static void wsp_ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, wsp_ggml_fp16_t * restrict y) { + wsp_ggml_float sumf[WSP_GGML_VEC_DOT_UNROLL] = { 0.0 }; + + wsp_ggml_fp16_t * restrict x[WSP_GGML_VEC_DOT_UNROLL]; + + for (int i = 0; i < WSP_GGML_VEC_DOT_UNROLL; ++i) { + x[i] = (wsp_ggml_fp16_t *) ((char *) xv + i*xs); + } + +#if defined(WSP_GGML_SIMD) + const int np = (n & ~(WSP_GGML_F16_STEP - 1)); + + WSP_GGML_F16_VEC sum[WSP_GGML_VEC_DOT_UNROLL][WSP_GGML_F16_ARR] = { { WSP_GGML_F16_VEC_ZERO } }; + + WSP_GGML_F16_VEC ax[WSP_GGML_F16_ARR]; + WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR]; + + for (int i = 0; i < np; i += 
WSP_GGML_F16_STEP) { + for (int j = 0; j < WSP_GGML_F16_ARR; j++) { + ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j); + + for (int k = 0; k < WSP_GGML_VEC_DOT_UNROLL; ++k) { + ax[j] = WSP_GGML_F16_VEC_LOAD(x[k] + i + j*WSP_GGML_F16_EPR, j); + + sum[k][j] = WSP_GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + } + } + } + + // reduce sum0..sum3 to sum0 + for (int k = 0; k < WSP_GGML_VEC_DOT_UNROLL; ++k) { + WSP_GGML_F16_VEC_REDUCE(sumf[k], sum[k]); + } + + // leftovers + for (int i = np; i < n; ++i) { + for (int j = 0; j < WSP_GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (wsp_ggml_float)(WSP_GGML_FP16_TO_FP32(x[j][i])*WSP_GGML_FP16_TO_FP32(y[i])); + } + } +#else + for (int i = 0; i < n; ++i) { + for (int j = 0; j < WSP_GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (wsp_ggml_float)(WSP_GGML_FP16_TO_FP32(x[j][i])*WSP_GGML_FP16_TO_FP32(y[i])); + } + } +#endif + + for (int i = 0; i < WSP_GGML_VEC_DOT_UNROLL; ++i) { + s[i] = sumf[i]; + } +} + +inline static void wsp_ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { +#if defined(WSP_GGML_SIMD) + const int np = (n & ~(WSP_GGML_F32_STEP - 1)); + + WSP_GGML_F32_VEC vx = WSP_GGML_F32_VEC_SET1(v); + + WSP_GGML_F32_VEC ax[WSP_GGML_F32_ARR]; + WSP_GGML_F32_VEC ay[WSP_GGML_F32_ARR]; + + for (int i = 0; i < np; i += WSP_GGML_F32_STEP) { + for (int j = 0; j < WSP_GGML_F32_ARR; j++) { + ax[j] = WSP_GGML_F32_VEC_LOAD(x + i + j*WSP_GGML_F32_EPR); + ay[j] = WSP_GGML_F32_VEC_LOAD(y + i + j*WSP_GGML_F32_EPR); + ay[j] = WSP_GGML_F32_VEC_FMA(ay[j], ax[j], vx); + + WSP_GGML_F32_VEC_STORE(y + i + j*WSP_GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += x[i]*v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += x[i]*v; + } +#endif +} + +//inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } +inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float v) { +#if defined(WSP_GGML_SIMD) + const int np = (n & ~(WSP_GGML_F32_STEP - 1)); + + WSP_GGML_F32_VEC vx = WSP_GGML_F32_VEC_SET1(v); + + WSP_GGML_F32_VEC ay[WSP_GGML_F32_ARR]; + + for (int i = 0; i < np; i += WSP_GGML_F32_STEP) { + for (int j = 0; j < WSP_GGML_F32_ARR; j++) { + ay[j] = WSP_GGML_F32_VEC_LOAD(y + i + j*WSP_GGML_F32_EPR); + ay[j] = WSP_GGML_F32_VEC_MUL(ay[j], vx); + + WSP_GGML_F32_VEC_STORE(y + i + j*WSP_GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] *= v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] *= v; + } +#endif +} + +inline static void wsp_ggml_vec_norm_f32 (const int n, float * s, const float * x) { wsp_ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } +inline static void wsp_ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } +inline static void wsp_ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } +inline static void wsp_ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } +inline static void wsp_ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } +inline static void wsp_ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? 
-1.f : 0.f); } +inline static void wsp_ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } +inline static void wsp_ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } +inline static void wsp_ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } +inline static void wsp_ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } + +static const float GELU_COEF_A = 0.044715f; +static const float GELU_QUICK_COEF = -1.702f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +inline static float wsp_ggml_gelu_f32(float x) { + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +inline static void wsp_ggml_vec_gelu_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = table_gelu_f16[i16[i]]; + } +} + +#ifdef WSP_GGML_GELU_FP16 +inline static void wsp_ggml_vec_gelu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + wsp_ggml_fp16_t fp16 = WSP_GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = WSP_GGML_FP16_TO_FP32(table_gelu_f16[t]); + } +} +#else +inline static void wsp_ggml_vec_gelu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = wsp_ggml_gelu_f32(x[i]); + } +} +#endif + +inline static float wsp_ggml_gelu_quick_f32(float x) { + return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); +} + +//inline static void wsp_ggml_vec_gelu_quick_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x) { +// const uint16_t * i16 = (const uint16_t *) x; +// for (int i = 0; i < n; ++i) { +// y[i] = table_gelu_quick_f16[i16[i]]; +// } +//} + +#ifdef WSP_GGML_GELU_QUICK_FP16 +inline static void wsp_ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + wsp_ggml_fp16_t fp16 = WSP_GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = WSP_GGML_FP16_TO_FP32(table_gelu_quick_f16[t]); + } +} +#else +inline static void wsp_ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = wsp_ggml_gelu_quick_f32(x[i]); + } +} +#endif + +// Sigmoid Linear Unit (SiLU) function +inline static float wsp_ggml_silu_f32(float x) { + return x/(1.0f + expf(-x)); +} + +//inline static void wsp_ggml_vec_silu_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x) { +// const uint16_t * i16 = (const uint16_t *) x; +// for (int i = 0; i < n; ++i) { +// y[i] = table_silu_f16[i16[i]]; +// } +//} + +#ifdef WSP_GGML_SILU_FP16 +inline static void wsp_ggml_vec_silu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + wsp_ggml_fp16_t fp16 = WSP_GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = WSP_GGML_FP16_TO_FP32(table_silu_f16[t]); + } +} +#else +inline static void wsp_ggml_vec_silu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = wsp_ggml_silu_f32(x[i]); + } +} +#endif + +inline static float wsp_ggml_silu_backward_f32(float x, float dy) { + const float s = 1.0f/(1.0f + expf(-x)); + return dy*s*(1.0f + x*(1.0f - s)); +} + +#ifdef WSP_GGML_SILU_FP16 +inline static void 
wsp_ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { + for (int i = 0; i < n; ++i) { + // we did not use x[i] to compute forward silu but its f16 equivalent + // take derivative at f16 of x[i]: + wsp_ggml_fp16_t fp16 = WSP_GGML_FP32_TO_FP16(x[i]); + float usedx = WSP_GGML_FP16_TO_FP32(fp16); + dx[i] = wsp_ggml_silu_backward_f32(usedx, dy[i]); + } +} +#else +inline static void wsp_ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { + for (int i = 0; i < n; ++i) { + dx[i] = wsp_ggml_silu_backward_f32(x[i], dy[i]); + } +} +#endif + +inline static void wsp_ggml_vec_sum_f32(const int n, float * s, const float * x) { +#ifndef WSP_GGML_USE_ACCELERATE + wsp_ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += (wsp_ggml_float)x[i]; + } + *s = sum; +#else + vDSP_sve(x, 1, s, n); +#endif +} + +inline static void wsp_ggml_vec_sum_ggf(const int n, wsp_ggml_float * s, const float * x) { + wsp_ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += (wsp_ggml_float)x[i]; + } + *s = sum; +} + +inline static void wsp_ggml_vec_max_f32(const int n, float * s, const float * x) { +#ifndef WSP_GGML_USE_ACCELERATE + float max = -INFINITY; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + } + *s = max; +#else + vDSP_maxv(x, 1, s, n); +#endif +} + +inline static void wsp_ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { + wsp_ggml_vec_norm_f32(n, s, x); + *s = 1.f/(*s); +} + +inline static void wsp_ggml_vec_argmax_f32(const int n, int * s, const float * x) { + float max = -INFINITY; + int idx = 0; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + if (max == x[i]) { idx = i; } + } + *s = idx; +} + +// +// data types +// + +static const int WSP_GGML_BLCK_SIZE[WSP_GGML_TYPE_COUNT] = { + [WSP_GGML_TYPE_F32] = 1, + [WSP_GGML_TYPE_F16] = 1, + [WSP_GGML_TYPE_Q4_0] = QK4_0, + [WSP_GGML_TYPE_Q4_1] = QK4_1, + [WSP_GGML_TYPE_Q5_0] = QK5_0, + [WSP_GGML_TYPE_Q5_1] = QK5_1, + [WSP_GGML_TYPE_Q8_0] = QK8_0, + [WSP_GGML_TYPE_Q8_1] = QK8_1, +#ifdef WSP_GGML_USE_K_QUANTS + [WSP_GGML_TYPE_Q2_K] = QK_K, + [WSP_GGML_TYPE_Q3_K] = QK_K, + [WSP_GGML_TYPE_Q4_K] = QK_K, + [WSP_GGML_TYPE_Q5_K] = QK_K, + [WSP_GGML_TYPE_Q6_K] = QK_K, + [WSP_GGML_TYPE_Q8_K] = QK_K, +#endif + [WSP_GGML_TYPE_I8] = 1, + [WSP_GGML_TYPE_I16] = 1, + [WSP_GGML_TYPE_I32] = 1, +}; +static_assert(WSP_GGML_TYPE_COUNT == 19, "WSP_GGML_BLCK_SIZE is outdated"); + +static const size_t WSP_GGML_TYPE_SIZE[WSP_GGML_TYPE_COUNT] = { + [WSP_GGML_TYPE_F32] = sizeof(float), + [WSP_GGML_TYPE_F16] = sizeof(wsp_ggml_fp16_t), + [WSP_GGML_TYPE_Q4_0] = sizeof(block_q4_0), + [WSP_GGML_TYPE_Q4_1] = sizeof(block_q4_1), + [WSP_GGML_TYPE_Q5_0] = sizeof(block_q5_0), + [WSP_GGML_TYPE_Q5_1] = sizeof(block_q5_1), + [WSP_GGML_TYPE_Q8_0] = sizeof(block_q8_0), + [WSP_GGML_TYPE_Q8_1] = sizeof(block_q8_1), +#ifdef WSP_GGML_USE_K_QUANTS + [WSP_GGML_TYPE_Q2_K] = sizeof(block_q2_K), + [WSP_GGML_TYPE_Q3_K] = sizeof(block_q3_K), + [WSP_GGML_TYPE_Q4_K] = sizeof(block_q4_K), + [WSP_GGML_TYPE_Q5_K] = sizeof(block_q5_K), + [WSP_GGML_TYPE_Q6_K] = sizeof(block_q6_K), + [WSP_GGML_TYPE_Q8_K] = sizeof(block_q8_K), +#endif + [WSP_GGML_TYPE_I8] = sizeof(int8_t), + [WSP_GGML_TYPE_I16] = sizeof(int16_t), + [WSP_GGML_TYPE_I32] = sizeof(int32_t), +}; +static_assert(WSP_GGML_TYPE_COUNT == 19, "WSP_GGML_TYPE_SIZE is outdated"); + + +static const char * WSP_GGML_TYPE_NAME[WSP_GGML_TYPE_COUNT] = { + [WSP_GGML_TYPE_F32] = "f32", + [WSP_GGML_TYPE_F16] = "f16", + [WSP_GGML_TYPE_Q4_0] = "q4_0", + 
[WSP_GGML_TYPE_Q4_1] = "q4_1", + [WSP_GGML_TYPE_Q5_0] = "q5_0", + [WSP_GGML_TYPE_Q5_1] = "q5_1", + [WSP_GGML_TYPE_Q8_0] = "q8_0", + [WSP_GGML_TYPE_Q8_1] = "q8_1", + [WSP_GGML_TYPE_Q2_K] = "q2_K", + [WSP_GGML_TYPE_Q3_K] = "q3_K", + [WSP_GGML_TYPE_Q4_K] = "q4_K", + [WSP_GGML_TYPE_Q5_K] = "q5_K", + [WSP_GGML_TYPE_Q6_K] = "q6_K", + [WSP_GGML_TYPE_Q8_K] = "q8_K", + [WSP_GGML_TYPE_I8] = "i8", + [WSP_GGML_TYPE_I16] = "i16", + [WSP_GGML_TYPE_I32] = "i32", +}; +static_assert(WSP_GGML_TYPE_COUNT == 19, "WSP_GGML_TYPE_NAME is outdated"); + +static bool WSP_GGML_IS_QUANTIZED[WSP_GGML_TYPE_COUNT] = { + [WSP_GGML_TYPE_F32] = false, + [WSP_GGML_TYPE_F16] = false, + [WSP_GGML_TYPE_Q4_0] = true, + [WSP_GGML_TYPE_Q4_1] = true, + [WSP_GGML_TYPE_Q5_0] = true, + [WSP_GGML_TYPE_Q5_1] = true, + [WSP_GGML_TYPE_Q8_0] = true, + [WSP_GGML_TYPE_Q8_1] = true, + [WSP_GGML_TYPE_Q2_K] = true, + [WSP_GGML_TYPE_Q3_K] = true, + [WSP_GGML_TYPE_Q4_K] = true, + [WSP_GGML_TYPE_Q5_K] = true, + [WSP_GGML_TYPE_Q6_K] = true, + [WSP_GGML_TYPE_Q8_K] = true, + [WSP_GGML_TYPE_I8] = false, + [WSP_GGML_TYPE_I16] = false, + [WSP_GGML_TYPE_I32] = false, +}; +static_assert(WSP_GGML_TYPE_COUNT == 19, "WSP_GGML_IS_QUANTIZED is outdated"); + +static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = { + "NONE", + + "DUP", + "ADD", + "ADD1", + "ACC", + "SUB", + "MUL", + "DIV", + "SQR", + "SQRT", + "LOG", + "SUM", + "SUM_ROWS", + "MEAN", + "ARGMAX", + "REPEAT", + "REPEAT_BACK", + "ABS", + "SGN", + "NEG", + "STEP", + "TANH", + "ELU", + "RELU", + "GELU", + "GELU_QUICK", + "SILU", + "SILU_BACK", + "NORM", + "RMS_NORM", + "RMS_NORM_BACK", + + "MUL_MAT", + "OUT_PROD", + + "SCALE", + "SET", + "CPY", + "CONT", + "RESHAPE", + "VIEW", + "PERMUTE", + "TRANSPOSE", + "GET_ROWS", + "GET_ROWS_BACK", + "DIAG", + "DIAG_MASK_INF", + "DIAG_MASK_ZERO", + "SOFT_MAX", + "SOFT_MAX_BACK", + "ROPE", + "ROPE_BACK", + "ALIBI", + "CLAMP", + "CONV_1D", + "CONV_2D", + + "FLASH_ATTN", + "FLASH_FF", + "FLASH_ATTN_BACK", + "WIN_PART", + "WIN_UNPART", + + "MAP_UNARY", + "MAP_BINARY", + + "MAP_CUSTOM1", + "MAP_CUSTOM2", + "MAP_CUSTOM3", + + "CROSS_ENTROPY_LOSS", + "CROSS_ENTROPY_LOSS_BACK", +}; + +static_assert(WSP_GGML_OP_COUNT == 66, "WSP_GGML_OP_COUNT != 66"); + +static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = { + "none", + + "x", + "x+y", + "x+y", + "view(x,nb,offset)+=y->x", + "x-y", + "x*y", + "x/y", + "x^2", + "√x", + "log(x)", + "Σx", + "Σx_k", + "Σx/n", + "argmax(x)", + "repeat(x)", + "repeat_back(x)", + "abs(x)", + "sgn(x)", + "-x", + "step(x)", + "tanh(x)", + "elu(x)", + "relu(x)", + "gelu(x)", + "gelu_quick(x)", + "silu(x)", + "silu_back(x)", + "norm(x)", + "rms_norm(x)", + "rms_norm_back(x)", + + "X*Y", + "X*Y", + + "x*v", + "y-\\>view(x)", + "x-\\>y", + "cont(x)", + "reshape(x)", + "view(x)", + "permute(x)", + "transpose(x)", + "get_rows(x)", + "get_rows_back(x)", + "diag(x)", + "diag_mask_inf(x)", + "diag_mask_zero(x)", + "soft_max(x)", + "soft_max_back(x)", + "rope(x)", + "rope_back(x)", + "alibi(x)", + "clamp(x)", + "conv_1d(x)", + "conv_2d(x)", + + "flash_attn(x)", + "flash_ff(x)", + "flash_attn_back(x)", + "win_part(x)", + "win_unpart(x)", + + "f(x)", + "f(x,y)", + + "custom(x)", + "custom(x,y)", + "custom(x,y,z)", + + "cross_entropy_loss(x,y)", + "cross_entropy_loss_back(x,y)", +}; + +static_assert(WSP_GGML_OP_COUNT == 66, "WSP_GGML_OP_COUNT != 66"); + +static_assert(sizeof(struct wsp_ggml_object)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_object size must be a multiple of WSP_GGML_MEM_ALIGN"); +static_assert(sizeof(struct 
wsp_ggml_tensor)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_tensor size must be a multiple of WSP_GGML_MEM_ALIGN"); + +// WARN: +// Misconfiguration can lead to problems that are hard to reason about: +// * At best it crashes or talks nonsense. +// * At worst it produces subtly different results that are hard to perceive. +// +// An op has to enable INIT or FINALIZE when any of its branches needs that pass. +// Take care with compile options (e.g., WSP_GGML_USE_xxx). +static bool WSP_GGML_OP_HAS_INIT [WSP_GGML_OP_COUNT] = { 0 }; +static bool WSP_GGML_OP_HAS_FINALIZE[WSP_GGML_OP_COUNT] = { 0 }; + +static void wsp_ggml_setup_op_has_task_pass(void) { + { // INIT + bool * p = WSP_GGML_OP_HAS_INIT; + + p[WSP_GGML_OP_ACC ] = true; + p[WSP_GGML_OP_MUL_MAT ] = true; + p[WSP_GGML_OP_OUT_PROD ] = true; + p[WSP_GGML_OP_SET ] = true; + p[WSP_GGML_OP_GET_ROWS_BACK ] = true; + p[WSP_GGML_OP_DIAG_MASK_INF ] = true; + p[WSP_GGML_OP_DIAG_MASK_ZERO ] = true; + p[WSP_GGML_OP_CONV_1D ] = true; + p[WSP_GGML_OP_CONV_2D ] = true; + p[WSP_GGML_OP_FLASH_ATTN_BACK ] = true; + p[WSP_GGML_OP_CROSS_ENTROPY_LOSS ] = true; + } + + { // FINALIZE + bool * p = WSP_GGML_OP_HAS_FINALIZE; + + p[WSP_GGML_OP_CROSS_ENTROPY_LOSS ] = true; + } +} + +// +// ggml context +// + +struct wsp_ggml_context { + size_t mem_size; + void * mem_buffer; + bool mem_buffer_owned; + bool no_alloc; + bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers + + int n_objects; + + struct wsp_ggml_object * objects_begin; + struct wsp_ggml_object * objects_end; + + struct wsp_ggml_scratch scratch; + struct wsp_ggml_scratch scratch_save; +}; + +struct wsp_ggml_context_container { + bool used; + + struct wsp_ggml_context context; +}; + +// +// NUMA support +// + +#define WSP_GGML_NUMA_MAX_NODES 8 +#define WSP_GGML_NUMA_MAX_CPUS 512 + +struct wsp_ggml_numa_node { + uint32_t cpus[WSP_GGML_NUMA_MAX_CPUS]; // hardware threads on this node + uint32_t n_cpus; +}; + +struct wsp_ggml_numa_nodes { + struct wsp_ggml_numa_node nodes[WSP_GGML_NUMA_MAX_NODES]; + uint32_t n_nodes; + uint32_t total_cpus; // hardware threads on system +}; + +// +// ggml state +// + +struct wsp_ggml_state { + struct wsp_ggml_context_container contexts[WSP_GGML_MAX_CONTEXTS]; + struct wsp_ggml_numa_nodes numa; +}; + +// global state +static struct wsp_ggml_state g_state; +static atomic_int g_state_barrier = 0; + +// barrier via spin lock +inline static void wsp_ggml_critical_section_start(void) { + int processing = atomic_fetch_add(&g_state_barrier, 1); + + while (processing > 0) { + // wait for other threads to finish + atomic_fetch_sub(&g_state_barrier, 1); + sched_yield(); // TODO: reconsider this + processing = atomic_fetch_add(&g_state_barrier, 1); + } +} + +// TODO: make this somehow automatically executed +// some sort of "sentry" mechanism +inline static void wsp_ggml_critical_section_end(void) { + atomic_fetch_sub(&g_state_barrier, 1); +} + +void wsp_ggml_numa_init(void) { + if (g_state.numa.n_nodes > 0) { + fprintf(stderr, "wsp_ggml_numa_init: NUMA already initialized\n"); + + return; + } + +#ifdef __linux__ + struct stat st; + char path[256]; + int rv; + + // enumerate nodes + while (g_state.numa.n_nodes < WSP_GGML_NUMA_MAX_NODES) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes); + WSP_GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) != 0) { break; } + ++g_state.numa.n_nodes; + } + + // enumerate CPUs + while (g_state.numa.total_cpus < WSP_GGML_NUMA_MAX_CPUS) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus); + WSP_GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) != 0) { break; } + ++g_state.numa.total_cpus; + } + + WSP_GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus); + + if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1) { + g_state.numa.n_nodes = 0; + return; + } + + for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) { + struct wsp_ggml_numa_node * node = &g_state.numa.nodes[n]; + WSP_GGML_PRINT_DEBUG("CPUs on node %u:", n); + node->n_cpus = 0; + for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c); + WSP_GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) == 0) { + node->cpus[node->n_cpus++] = c; + WSP_GGML_PRINT_DEBUG(" %u", c); + } + } + WSP_GGML_PRINT_DEBUG("\n"); + } + + if (wsp_ggml_is_numa()) { + FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r"); + if (fptr != NULL) { + char buf[42]; + if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) { + WSP_GGML_PRINT("WARNING: /proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n"); + } + fclose(fptr); + } + } +#else + // TODO +#endif +} + +bool wsp_ggml_is_numa(void) { + return g_state.numa.n_nodes > 1; +} + +//////////////////////////////////////////////////////////////////////////////// + +void wsp_ggml_print_object(const struct wsp_ggml_object * obj) { + WSP_GGML_PRINT(" - wsp_ggml_object: offset = %zu, size = %zu, next = %p\n", + obj->offs, obj->size, (const void *) obj->next); +} + +void wsp_ggml_print_objects(const struct wsp_ggml_context * ctx) { + struct wsp_ggml_object * obj = ctx->objects_begin; + + WSP_GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx); + + while (obj != NULL) { + wsp_ggml_print_object(obj); + obj = obj->next; + } + + WSP_GGML_PRINT("%s: --- end ---\n", __func__); +} + +int64_t wsp_ggml_nelements(const struct wsp_ggml_tensor * tensor) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +int64_t wsp_ggml_nrows(const struct wsp_ggml_tensor * tensor) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +size_t wsp_ggml_nbytes(const struct wsp_ggml_tensor * tensor) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + // this should handle cases where the tensor is not contiguous in memory + // probably just: + // + // return tensor->ne[3]*tensor->nb[3] + // + // is enough, but just in case, adding the second part + + return MAX(tensor->ne[3]*tensor->nb[3], (wsp_ggml_nelements(tensor)*WSP_GGML_TYPE_SIZE[tensor->type])/WSP_GGML_BLCK_SIZE[tensor->type]); +} + +size_t wsp_ggml_nbytes_split(const struct wsp_ggml_tensor * tensor, int nrows_split) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return (nrows_split*tensor->ne[0]*WSP_GGML_TYPE_SIZE[tensor->type])/WSP_GGML_BLCK_SIZE[tensor->type]; +} + +int wsp_ggml_blck_size(enum wsp_ggml_type type) { + return WSP_GGML_BLCK_SIZE[type]; +} + +size_t wsp_ggml_type_size(enum wsp_ggml_type type) { + return WSP_GGML_TYPE_SIZE[type]; +} + +float wsp_ggml_type_sizef(enum wsp_ggml_type type) { + return
((float)(WSP_GGML_TYPE_SIZE[type]))/WSP_GGML_BLCK_SIZE[type]; +} + +const char * wsp_ggml_type_name(enum wsp_ggml_type type) { + return WSP_GGML_TYPE_NAME[type]; +} + +const char * wsp_ggml_op_name(enum wsp_ggml_op op) { + return WSP_GGML_OP_NAME[op]; +} + +size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor) { + return WSP_GGML_TYPE_SIZE[tensor->type]; +} + +static inline bool wsp_ggml_is_scalar(const struct wsp_ggml_tensor * tensor) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool wsp_ggml_is_vector(const struct wsp_ggml_tensor * tensor) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool wsp_ggml_is_matrix(const struct wsp_ggml_tensor * tensor) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool wsp_ggml_can_mul_mat(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[0] == t1->ne[0]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); +} + +static inline bool wsp_ggml_can_out_prod(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[1] == t1->ne[1]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); +} + +bool wsp_ggml_is_quantized(enum wsp_ggml_type type) { + return WSP_GGML_IS_QUANTIZED[type]; +} + +enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype) { + enum wsp_ggml_type wtype = WSP_GGML_TYPE_COUNT; + + switch (ftype) { + case WSP_GGML_FTYPE_ALL_F32: wtype = WSP_GGML_TYPE_F32; break; + case WSP_GGML_FTYPE_MOSTLY_F16: wtype = WSP_GGML_TYPE_F16; break; + case WSP_GGML_FTYPE_MOSTLY_Q4_0: wtype = WSP_GGML_TYPE_Q4_0; break; + case WSP_GGML_FTYPE_MOSTLY_Q4_1: wtype = WSP_GGML_TYPE_Q4_1; break; + case WSP_GGML_FTYPE_MOSTLY_Q5_0: wtype = WSP_GGML_TYPE_Q5_0; break; + case WSP_GGML_FTYPE_MOSTLY_Q5_1: wtype = WSP_GGML_TYPE_Q5_1; break; + case WSP_GGML_FTYPE_MOSTLY_Q8_0: wtype = WSP_GGML_TYPE_Q8_0; break; + case WSP_GGML_FTYPE_MOSTLY_Q2_K: wtype = WSP_GGML_TYPE_Q2_K; break; + case WSP_GGML_FTYPE_MOSTLY_Q3_K: wtype = WSP_GGML_TYPE_Q3_K; break; + case WSP_GGML_FTYPE_MOSTLY_Q4_K: wtype = WSP_GGML_TYPE_Q4_K; break; + case WSP_GGML_FTYPE_MOSTLY_Q5_K: wtype = WSP_GGML_TYPE_Q5_K; break; + case WSP_GGML_FTYPE_MOSTLY_Q6_K: wtype = WSP_GGML_TYPE_Q6_K; break; + case WSP_GGML_FTYPE_UNKNOWN: wtype = WSP_GGML_TYPE_COUNT; break; + case WSP_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = WSP_GGML_TYPE_COUNT; break; + } + + WSP_GGML_ASSERT(wtype != WSP_GGML_TYPE_COUNT); + + return wtype; +} + +size_t wsp_ggml_tensor_overhead(void) { + return WSP_GGML_OBJECT_SIZE + WSP_GGML_TENSOR_SIZE + 16; +} + +bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor) { + return tensor->nb[0] > tensor->nb[1]; +} + +bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == WSP_GGML_TYPE_SIZE[tensor->type] && + tensor->nb[1] == 
(tensor->nb[0]*tensor->ne[0])/WSP_GGML_BLCK_SIZE[tensor->type] && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +bool wsp_ggml_is_permuted(const struct wsp_ggml_tensor * tensor) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; +} + +static inline bool wsp_ggml_is_padded_1d(const struct wsp_ggml_tensor * tensor) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == WSP_GGML_TYPE_SIZE[tensor->type] && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +static inline bool wsp_ggml_are_same_shape(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[0] == t1->ne[0] ) && + (t0->ne[1] == t1->ne[1] ) && + (t0->ne[2] == t1->ne[2] ) && + (t0->ne[3] == t1->ne[3] ); +} + +// check if t1 can be represented as a repetition of t0 +static inline bool wsp_ggml_can_repeat(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return + (t1->ne[0]%t0->ne[0] == 0) && + (t1->ne[1]%t0->ne[1] == 0) && + (t1->ne[2]%t0->ne[2] == 0) && + (t1->ne[3]%t0->ne[3] == 0); +} + +static inline bool wsp_ggml_can_repeat_rows(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1) { + static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && wsp_ggml_can_repeat(t0, t1); +} + +static inline int wsp_ggml_up32(int n) { + return (n + 31) & ~31; +} + +//static inline int wsp_ggml_up64(int n) { +// return (n + 63) & ~63; +//} + +static inline int wsp_ggml_up(int n, int m) { + // assert m is a power of 2 + WSP_GGML_ASSERT((m & (m - 1)) == 0); + return (n + m - 1) & ~(m - 1); +} + +// assert that pointer is aligned to WSP_GGML_MEM_ALIGN +#define wsp_ggml_assert_aligned(ptr) \ + WSP_GGML_ASSERT(((uintptr_t) (ptr))%WSP_GGML_MEM_ALIGN == 0) + +//////////////////////////////////////////////////////////////////////////////// + +struct wsp_ggml_context * wsp_ggml_init(struct wsp_ggml_init_params params) { + // make this function thread safe + wsp_ggml_critical_section_start(); + + static bool is_first_call = true; + + if (is_first_call) { + // initialize time system (required on Windows) + wsp_ggml_time_init(); + + // initialize GELU, Quick GELU, SILU and EXP F32 tables + { + const uint64_t t_start = wsp_ggml_time_us(); UNUSED(t_start); + + wsp_ggml_fp16_t ii; + for (int i = 0; i < (1 << 16); ++i) { + uint16_t ui = i; + memcpy(&ii, &ui, sizeof(ii)); + const float f = table_f32_f16[i] = WSP_GGML_COMPUTE_FP16_TO_FP32(ii); + table_gelu_f16[i] = WSP_GGML_FP32_TO_FP16(wsp_ggml_gelu_f32(f)); + table_gelu_quick_f16[i] = WSP_GGML_FP32_TO_FP16(wsp_ggml_gelu_quick_f32(f)); + table_silu_f16[i] = WSP_GGML_FP32_TO_FP16(wsp_ggml_silu_f32(f)); + table_exp_f16[i] = WSP_GGML_FP32_TO_FP16(expf(f)); + } + + const uint64_t t_end = wsp_ggml_time_us(); UNUSED(t_end); + + WSP_GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } + + // initialize g_state + { + const uint64_t t_start = wsp_ggml_time_us(); UNUSED(t_start); + + g_state
= (struct wsp_ggml_state) { + /*.contexts =*/ { { 0 } }, + /*.numa =*/ { + .n_nodes = 0, + .total_cpus = 0, + }, + }; + + for (int i = 0; i < WSP_GGML_MAX_CONTEXTS; ++i) { + g_state.contexts[i].used = false; + } + + const uint64_t t_end = wsp_ggml_time_us(); UNUSED(t_end); + + WSP_GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } + +#if defined(WSP_GGML_USE_CUBLAS) + wsp_ggml_init_cublas(); +#elif defined(WSP_GGML_USE_CLBLAST) + wsp_ggml_cl_init(); +#endif + + wsp_ggml_setup_op_has_task_pass(); + + is_first_call = false; + } + + // find non-used context in g_state + struct wsp_ggml_context * ctx = NULL; + + for (int i = 0; i < WSP_GGML_MAX_CONTEXTS; i++) { + if (!g_state.contexts[i].used) { + g_state.contexts[i].used = true; + ctx = &g_state.contexts[i].context; + + WSP_GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i); + break; + } + } + + if (ctx == NULL) { + WSP_GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); + + wsp_ggml_critical_section_end(); + + return NULL; + } + + const size_t mem_size = (params.mem_size + WSP_GGML_MEM_ALIGN - 1) & ~(WSP_GGML_MEM_ALIGN - 1); + + *ctx = (struct wsp_ggml_context) { + /*.mem_size =*/ mem_size, + /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : WSP_GGML_ALIGNED_MALLOC(mem_size), + /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, + /*.no_alloc =*/ params.no_alloc, + /*.no_alloc_save =*/ params.no_alloc, + /*.n_objects =*/ 0, + /*.objects_begin =*/ NULL, + /*.objects_end =*/ NULL, + /*.scratch =*/ { 0, 0, NULL, }, + /*.scratch_save =*/ { 0, 0, NULL, }, + }; + + WSP_GGML_ASSERT(ctx->mem_buffer != NULL); + + wsp_ggml_assert_aligned(ctx->mem_buffer); + + WSP_GGML_PRINT_DEBUG("%s: context initialized\n", __func__); + + wsp_ggml_critical_section_end(); + + return ctx; +} + +void wsp_ggml_free(struct wsp_ggml_context * ctx) { + // make this function thread safe + wsp_ggml_critical_section_start(); + + bool found = false; + + for (int i = 0; i < WSP_GGML_MAX_CONTEXTS; i++) { + if (&g_state.contexts[i].context == ctx) { + g_state.contexts[i].used = false; + + WSP_GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n", + __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size); + + if (ctx->mem_buffer_owned) { + WSP_GGML_ALIGNED_FREE(ctx->mem_buffer); + } + + found = true; + break; + } + } + + if (!found) { + WSP_GGML_PRINT_DEBUG("%s: context not found\n", __func__); + } + + wsp_ggml_critical_section_end(); +} + +size_t wsp_ggml_used_mem(const struct wsp_ggml_context * ctx) { + return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size; +} + +size_t wsp_ggml_set_scratch(struct wsp_ggml_context * ctx, struct wsp_ggml_scratch scratch) { + const size_t result = ctx->scratch.data ? 
ctx->scratch.offs : 0; + + ctx->scratch = scratch; + + return result; +} + +void wsp_ggml_set_no_alloc(struct wsp_ggml_context * ctx, bool no_alloc) { + ctx->no_alloc = no_alloc; +} + +void * wsp_ggml_get_mem_buffer(const struct wsp_ggml_context * ctx) { + return ctx->mem_buffer; +} + +size_t wsp_ggml_get_mem_size(const struct wsp_ggml_context * ctx) { + return ctx->mem_size; +} + +size_t wsp_ggml_get_max_tensor_size(const struct wsp_ggml_context * ctx) { + size_t max_size = 0; + + struct wsp_ggml_object * obj = ctx->objects_begin; + + while (obj != NULL) { + struct wsp_ggml_tensor * tensor = (struct wsp_ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs); + + const size_t size = wsp_ggml_nbytes(tensor); + + if (max_size < size) { + max_size = size; + } + + obj = obj->next; + } + + return max_size; +} + +// IMPORTANT: +// when creating "opt" tensors, always save and load the scratch buffer +// this is an error prone process, but it is necessary to support inplace +// operators when using scratch buffers +// TODO: implement a better way +void wsp_ggml_scratch_save(struct wsp_ggml_context * ctx) { + // this is needed to allow opt tensors to store their data + // TODO: again, need to find a better way + ctx->no_alloc_save = ctx->no_alloc; + ctx->no_alloc = false; + + ctx->scratch_save = ctx->scratch; + ctx->scratch.data = NULL; +} + +void wsp_ggml_scratch_load(struct wsp_ggml_context * ctx) { + ctx->no_alloc = ctx->no_alloc_save; + + ctx->scratch = ctx->scratch_save; +} + +//////////////////////////////////////////////////////////////////////////////// + +struct wsp_ggml_tensor * wsp_ggml_new_tensor_impl( + struct wsp_ggml_context * ctx, + enum wsp_ggml_type type, + int n_dims, + const int64_t* ne, + void* data) { + // always insert objects at the end of the context's memory pool + struct wsp_ggml_object * obj_cur = ctx->objects_end; + + const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; + const size_t cur_size = obj_cur == NULL ? 
0 : obj_cur->size; + const size_t cur_end = cur_offs + cur_size; + + size_t size_needed = 0; + + if (data == NULL && !ctx->no_alloc) { + size_needed += WSP_GGML_TYPE_SIZE[type]*(ne[0]/WSP_GGML_BLCK_SIZE[type]); + for (int i = 1; i < n_dims; i++) { + size_needed *= ne[i]; + } + // align to WSP_GGML_MEM_ALIGN + size_needed = ((size_needed + WSP_GGML_MEM_ALIGN - 1)/WSP_GGML_MEM_ALIGN)*WSP_GGML_MEM_ALIGN; + } + + char * const mem_buffer = ctx->mem_buffer; + struct wsp_ggml_object * const obj_new = (struct wsp_ggml_object *)(mem_buffer + cur_end); + + if (ctx->scratch.data == NULL || data != NULL) { + size_needed += WSP_GGML_TENSOR_SIZE; + + if (cur_end + size_needed + WSP_GGML_OBJECT_SIZE > ctx->mem_size) { + WSP_GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + size_needed + WSP_GGML_OBJECT_SIZE, ctx->mem_size); + assert(false); + return NULL; + } + + *obj_new = (struct wsp_ggml_object) { + .offs = cur_end + WSP_GGML_OBJECT_SIZE, + .size = size_needed, + .next = NULL, + }; + } else { + if (ctx->scratch.offs + size_needed > ctx->scratch.size) { + WSP_GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n", + __func__, ctx->scratch.offs + size_needed, ctx->scratch.size); + assert(false); + return NULL; + } + + if (cur_end + WSP_GGML_TENSOR_SIZE + WSP_GGML_OBJECT_SIZE > ctx->mem_size) { + WSP_GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + WSP_GGML_TENSOR_SIZE + WSP_GGML_OBJECT_SIZE, ctx->mem_size); + assert(false); + return NULL; + } + + data = (char * const) ctx->scratch.data + ctx->scratch.offs; + + *obj_new = (struct wsp_ggml_object) { + .offs = cur_end + WSP_GGML_OBJECT_SIZE, + .size = WSP_GGML_TENSOR_SIZE, + .next = NULL, + }; + + //printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed); + + ctx->scratch.offs += size_needed; + } + + if (obj_cur != NULL) { + obj_cur->next = obj_new; + } else { + // this is the first object in this context + ctx->objects_begin = obj_new; + } + + ctx->objects_end = obj_new; + + //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); + + struct wsp_ggml_tensor * const result = (struct wsp_ggml_tensor *)(mem_buffer + obj_new->offs); + + wsp_ggml_assert_aligned(result); + + *result = (struct wsp_ggml_tensor) { + /*.type =*/ type, + /*.backend =*/ WSP_GGML_BACKEND_CPU, + /*.n_dims =*/ n_dims, + /*.ne =*/ { 1, 1, 1, 1 }, + /*.nb =*/ { 0, 0, 0, 0 }, + /*.op =*/ WSP_GGML_OP_NONE, + /*.is_param =*/ false, + /*.grad =*/ NULL, + /*.src0 =*/ NULL, + /*.src1 =*/ NULL, + /*.opt =*/ { NULL }, + /*.n_tasks =*/ 0, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + /*.data =*/ (data == NULL && !ctx->no_alloc) ? 
(void *)(result + 1) : data, + /*.name =*/ { 0 }, + /*.extra =*/ NULL, + /*.pad =*/ { 0 }, + }; + + // TODO: this should not be needed as long as we don't rely on aligned SIMD loads + //wsp_ggml_assert_aligned(result->data); + + for (int i = 0; i < n_dims; i++) { + result->ne[i] = ne[i]; + } + + result->nb[0] = WSP_GGML_TYPE_SIZE[type]; + result->nb[1] = result->nb[0]*(result->ne[0]/WSP_GGML_BLCK_SIZE[type]); + for (int i = 2; i < WSP_GGML_MAX_DIMS; i++) { + result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; + } + + ctx->n_objects++; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_new_tensor( + struct wsp_ggml_context * ctx, + enum wsp_ggml_type type, + int n_dims, + const int64_t * ne) { + return wsp_ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL); +} + +struct wsp_ggml_tensor * wsp_ggml_new_tensor_1d( + struct wsp_ggml_context * ctx, + enum wsp_ggml_type type, + int64_t ne0) { + return wsp_ggml_new_tensor(ctx, type, 1, &ne0); +} + +struct wsp_ggml_tensor * wsp_ggml_new_tensor_2d( + struct wsp_ggml_context * ctx, + enum wsp_ggml_type type, + int64_t ne0, + int64_t ne1) { + const int64_t ne[2] = { ne0, ne1 }; + return wsp_ggml_new_tensor(ctx, type, 2, ne); +} + +struct wsp_ggml_tensor * wsp_ggml_new_tensor_3d( + struct wsp_ggml_context * ctx, + enum wsp_ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + const int64_t ne[3] = { ne0, ne1, ne2 }; + return wsp_ggml_new_tensor(ctx, type, 3, ne); +} + +struct wsp_ggml_tensor * wsp_ggml_new_tensor_4d( + struct wsp_ggml_context * ctx, + enum wsp_ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + return wsp_ggml_new_tensor(ctx, type, 4, ne); +} + +struct wsp_ggml_tensor * wsp_ggml_new_i32(struct wsp_ggml_context * ctx, int32_t value) { + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 1); + + wsp_ggml_scratch_load(ctx); + + wsp_ggml_set_i32(result, value); + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_new_f32(struct wsp_ggml_context * ctx, float value) { + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 1); + + wsp_ggml_scratch_load(ctx); + + wsp_ggml_set_f32(result, value); + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_dup_tensor(struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src) { + return wsp_ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL); +} + +struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor) { + memset(tensor->data, 0, wsp_ggml_nbytes(tensor)); + return tensor; +} + +struct wsp_ggml_tensor * wsp_ggml_set_i32 (struct wsp_ggml_tensor * tensor, int32_t value) { + const int n = wsp_ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case WSP_GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + wsp_ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case WSP_GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + wsp_ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case WSP_GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + wsp_ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case WSP_GGML_TYPE_F16: + { + assert(tensor->nb[0] == 
sizeof(wsp_ggml_fp16_t)); + for (int i = 0; i < n; i++) { + wsp_ggml_vec_set_f16(nc, (wsp_ggml_fp16_t *)(data + i*n1), value); + } + } break; + case WSP_GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + wsp_ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } + + return tensor; +} + +struct wsp_ggml_tensor * wsp_ggml_set_f32(struct wsp_ggml_tensor * tensor, float value) { + const int n = wsp_ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case WSP_GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + wsp_ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case WSP_GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + wsp_ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case WSP_GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + wsp_ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case WSP_GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(wsp_ggml_fp16_t)); + for (int i = 0; i < n; i++) { + wsp_ggml_vec_set_f16(nc, (wsp_ggml_fp16_t *)(data + i*n1), value); + } + } break; + case WSP_GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + wsp_ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } + + return tensor; +} + +int32_t wsp_ggml_get_i32_1d(const struct wsp_ggml_tensor * tensor, int i) { + switch (tensor->type) { + case WSP_GGML_TYPE_I8: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } break; + case WSP_GGML_TYPE_I16: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } break; + case WSP_GGML_TYPE_I32: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } break; + case WSP_GGML_TYPE_F16: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(wsp_ggml_fp16_t)); + return WSP_GGML_FP16_TO_FP32(((wsp_ggml_fp16_t *)(tensor->data))[i]); + } break; + case WSP_GGML_TYPE_F32: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float *)(tensor->data))[i]; + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } + + return 0.0f; +} + +void wsp_ggml_set_i32_1d(const struct wsp_ggml_tensor * tensor, int i, int32_t value) { + switch (tensor->type) { + case WSP_GGML_TYPE_I8: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case WSP_GGML_TYPE_I16: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case WSP_GGML_TYPE_I32: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case WSP_GGML_TYPE_F16: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(wsp_ggml_fp16_t)); + ((wsp_ggml_fp16_t *)(tensor->data))[i] = WSP_GGML_FP32_TO_FP16(value); + } break; + case WSP_GGML_TYPE_F32: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +float wsp_ggml_get_f32_1d(const struct wsp_ggml_tensor * tensor, int i) { + switch (tensor->type) { + case 
WSP_GGML_TYPE_I8: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } break; + case WSP_GGML_TYPE_I16: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } break; + case WSP_GGML_TYPE_I32: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } break; + case WSP_GGML_TYPE_F16: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(wsp_ggml_fp16_t)); + return WSP_GGML_FP16_TO_FP32(((wsp_ggml_fp16_t *)(tensor->data))[i]); + } break; + case WSP_GGML_TYPE_F32: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float *)(tensor->data))[i]; + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } + + return 0.0f; +} + +void wsp_ggml_set_f32_1d(const struct wsp_ggml_tensor * tensor, int i, float value) { + switch (tensor->type) { + case WSP_GGML_TYPE_I8: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case WSP_GGML_TYPE_I16: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case WSP_GGML_TYPE_I32: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case WSP_GGML_TYPE_F16: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(wsp_ggml_fp16_t)); + ((wsp_ggml_fp16_t *)(tensor->data))[i] = WSP_GGML_FP32_TO_FP16(value); + } break; + case WSP_GGML_TYPE_F32: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +void * wsp_ggml_get_data(const struct wsp_ggml_tensor * tensor) { + return tensor->data; +} + +float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor) { + assert(tensor->type == WSP_GGML_TYPE_F32); + return (float *)(tensor->data); +} + +const char * wsp_ggml_get_name(const struct wsp_ggml_tensor * tensor) { + return tensor->name; +} + +struct wsp_ggml_tensor * wsp_ggml_set_name(struct wsp_ggml_tensor * tensor, const char * name) { + strncpy(tensor->name, name, sizeof(tensor->name)); + tensor->name[sizeof(tensor->name) - 1] = '\0'; + return tensor; +} + +struct wsp_ggml_tensor * wsp_ggml_format_name(struct wsp_ggml_tensor * tensor, const char * fmt, ...) 
{ + va_list args; + va_start(args, fmt); + vsnprintf(tensor->name, sizeof(tensor->name), fmt, args); + va_end(args); + return tensor; +} + +struct wsp_ggml_tensor * wsp_ggml_view_tensor( + struct wsp_ggml_context * ctx, + const struct wsp_ggml_tensor * src) { + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data); + wsp_ggml_format_name(result, "%s (view)", src->name); + + result->nb[0] = src->nb[0]; + result->nb[1] = src->nb[1]; + result->nb[2] = src->nb[2]; + result->nb[3] = src->nb[3]; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name) { + struct wsp_ggml_object * obj = ctx->objects_begin; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + struct wsp_ggml_tensor * cur = (struct wsp_ggml_tensor *)(mem_buffer + obj->offs); + if (strcmp(cur->name, name) == 0) { + return cur; + } + + obj = obj->next; + } + + return NULL; +} + +//////////////////////////////////////////////////////////////////////////////// + +// wsp_ggml_dup + +struct wsp_ggml_tensor * wsp_ggml_dup_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_DUP; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_dup( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_dup_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_dup_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_dup_impl(ctx, a, true); +} + +// wsp_ggml_add + +struct wsp_ggml_tensor * wsp_ggml_add_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + bool inplace) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_ADD; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_add( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_add_impl(ctx, a, b, false); +} + +struct wsp_ggml_tensor * wsp_ggml_add_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_add_impl(ctx, a, b, true); +} + +// wsp_ggml_add1 + +struct wsp_ggml_tensor * wsp_ggml_add1_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + bool inplace) { + WSP_GGML_ASSERT(wsp_ggml_is_scalar(b)); + WSP_GGML_ASSERT(wsp_ggml_is_padded_1d(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_ADD1; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_add1( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_add1_impl(ctx, a, b, false); +} + +struct wsp_ggml_tensor * wsp_ggml_add1_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_add1_impl(ctx, a, b, true); +} + +// wsp_ggml_acc + +struct wsp_ggml_tensor * wsp_ggml_acc_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, + bool inplace) { + WSP_GGML_ASSERT(wsp_ggml_nelements(b) <= wsp_ggml_nelements(a)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a)); + WSP_GGML_ASSERT(a->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_F32); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * c = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 5); + + ((int32_t *) c->data)[0] = nb1; + ((int32_t *) c->data)[1] = nb2; + ((int32_t *) c->data)[2] = nb3; + ((int32_t *) c->data)[3] = offset; + ((int32_t *) c->data)[4] = inplace ? 1 : 0; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_ACC; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + result->opt[0] = c; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_acc( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return wsp_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); +} + +struct wsp_ggml_tensor * wsp_ggml_acc_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return wsp_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); +} + +// wsp_ggml_sub + +struct wsp_ggml_tensor * wsp_ggml_sub_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + bool inplace) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_SUB; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_sub( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_sub_impl(ctx, a, b, false); +} + +struct wsp_ggml_tensor * wsp_ggml_sub_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_sub_impl(ctx, a, b, true); +} + +// wsp_ggml_mul + +struct wsp_ggml_tensor * wsp_ggml_mul_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + bool inplace) { + // TODO: support less-strict constraint + // WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a)); + WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(b, a)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + // TODO: support backward pass for broadcasting + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b)); + is_node = true; + } + + if (inplace) { + WSP_GGML_ASSERT(is_node == false); + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_MUL; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_mul( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_mul_impl(ctx, a, b, false); +} + +struct wsp_ggml_tensor * wsp_ggml_mul_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_mul_impl(ctx, a, b, true); +} + +// wsp_ggml_div + +struct wsp_ggml_tensor * wsp_ggml_div_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + bool inplace) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + if (inplace) { + WSP_GGML_ASSERT(is_node == false); + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_DIV; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_div( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_div_impl(ctx, a, b, false); +} + +struct wsp_ggml_tensor * wsp_ggml_div_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_div_impl(ctx, a, b, true); +} + +// wsp_ggml_sqr + +struct wsp_ggml_tensor * wsp_ggml_sqr_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_SQR; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_sqr( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_sqr_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_sqr_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_sqr_impl(ctx, a, true); +} + +// wsp_ggml_sqrt + +struct wsp_ggml_tensor * wsp_ggml_sqrt_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_SQRT; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_sqrt( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_sqrt_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_sqrt_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_sqrt_impl(ctx, a, true); +} + + +// wsp_ggml_log + +struct wsp_ggml_tensor * wsp_ggml_log_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_LOG; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_log( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_log_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_log_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_log_impl(ctx, a, true); +} + +// wsp_ggml_sum + +struct wsp_ggml_tensor * wsp_ggml_sum( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = WSP_GGML_OP_SUM; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + + +// wsp_ggml_sum_rows + +struct wsp_ggml_tensor * wsp_ggml_sum_rows( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + int64_t ne[4] = {1,1,1,1}; + for (int i=1; i < a->n_dims; ++i) { + ne[i] = a->ne[i]; + } + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, a->n_dims, ne); + + result->op = WSP_GGML_OP_SUM_ROWS; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// wsp_ggml_mean + +struct wsp_ggml_tensor * wsp_ggml_mean( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + WSP_GGML_ASSERT(false); // TODO: implement + is_node = true; + } + + int64_t ne[WSP_GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, a->n_dims, ne); + + result->op = WSP_GGML_OP_MEAN; + result->grad = is_node ?
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// wsp_ggml_argmax + +struct wsp_ggml_tensor * wsp_ggml_argmax( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + WSP_GGML_ASSERT(wsp_ggml_is_matrix(a)); + bool is_node = false; + + if (a->grad) { + WSP_GGML_ASSERT(false); + is_node = true; + } + + int64_t ne[WSP_GGML_MAX_DIMS] = { a->ne[1], 1, 1, 1 }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_I32, a->n_dims, ne); + + result->op = WSP_GGML_OP_ARGMAX; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// wsp_ggml_repeat + +struct wsp_ggml_tensor * wsp_ggml_repeat( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + WSP_GGML_ASSERT(wsp_ggml_can_repeat(a, b)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + if (wsp_ggml_are_same_shape(a, b) && !is_node) { + return a; + } + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); + + result->op = WSP_GGML_OP_REPEAT; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// wsp_ggml_repeat_back + +struct wsp_ggml_tensor * wsp_ggml_repeat_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + if (wsp_ggml_are_same_shape(a, b) && !is_node) { + return a; + } + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); + + result->op = WSP_GGML_OP_REPEAT_BACK; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// wsp_ggml_abs + +struct wsp_ggml_tensor * wsp_ggml_abs_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_ABS; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_abs( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_abs_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_abs_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_abs_impl(ctx, a, true); +} + + +// wsp_ggml_sgn + +struct wsp_ggml_tensor * wsp_ggml_sgn_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_SGN; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_sgn( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_sgn_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_sgn_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_sgn_impl(ctx, a, true); +} + +// wsp_ggml_neg + +struct wsp_ggml_tensor * wsp_ggml_neg_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_NEG; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_neg( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_neg_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_neg_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_neg_impl(ctx, a, true); +} + +// wsp_ggml_step + +struct wsp_ggml_tensor * wsp_ggml_step_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_STEP; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_step( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_step_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_step_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_step_impl(ctx, a, true); +} + +// wsp_ggml_tanh + +struct wsp_ggml_tensor * wsp_ggml_tanh_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_TANH; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_tanh( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_tanh_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_tanh_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_tanh_impl(ctx, a, true); +} + +// wsp_ggml_elu + +struct wsp_ggml_tensor * wsp_ggml_elu_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_ELU; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_elu( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_elu_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_elu_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_elu_impl(ctx, a, true); +} + +// wsp_ggml_relu + +struct wsp_ggml_tensor * wsp_ggml_relu_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_RELU; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_relu( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_relu_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_relu_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_relu_impl(ctx, a, true); +} + +// wsp_ggml_gelu + +struct wsp_ggml_tensor * wsp_ggml_gelu_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_GELU; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_gelu( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_gelu_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_gelu_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_gelu_impl(ctx, a, true); +} + +// wsp_ggml_gelu_quick + +struct wsp_ggml_tensor * wsp_ggml_gelu_quick_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_GELU_QUICK; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_gelu_quick( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_gelu_quick_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_gelu_quick_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_gelu_quick_impl(ctx, a, true); +} + +// wsp_ggml_silu + +struct wsp_ggml_tensor * wsp_ggml_silu_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_SILU; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_silu( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_silu_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_silu_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_silu_impl(ctx, a, true); +} + +// wsp_ggml_silu_back + +struct wsp_ggml_tensor * wsp_ggml_silu_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + bool is_node = false; + + if (a->grad || b->grad) { + // TODO: implement backward + is_node = true; + } + + struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_SILU_BACK; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// wsp_ggml_norm + +struct wsp_ggml_tensor * wsp_ggml_norm_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + WSP_GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_NORM; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store epsilon here? + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_norm( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_norm_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_norm_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_norm_impl(ctx, a, true); +} + +struct wsp_ggml_tensor * wsp_ggml_rms_norm_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_RMS_NORM; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store epsilon here? + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_rms_norm( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_rms_norm_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_rms_norm_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_rms_norm_impl(ctx, a, true); +} + +struct wsp_ggml_tensor * wsp_ggml_rms_norm_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + bool is_node = false; + + if (a->grad) { + // TODO: implement backward + is_node = true; + } + + struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_RMS_NORM_BACK; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + + +// wsp_ggml_mul_mat + +struct wsp_ggml_tensor * wsp_ggml_mul_mat( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + WSP_GGML_ASSERT(wsp_ggml_can_mul_mat(a, b)); + WSP_GGML_ASSERT(!wsp_ggml_is_transposed(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + + result->op = WSP_GGML_OP_MUL_MAT; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// wsp_ggml_out_prod + +struct wsp_ggml_tensor * wsp_ggml_out_prod( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + WSP_GGML_ASSERT(wsp_ggml_can_out_prod(a, b)); + WSP_GGML_ASSERT(!wsp_ggml_is_transposed(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + + result->op = WSP_GGML_OP_OUT_PROD; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// wsp_ggml_scale + +struct wsp_ggml_tensor * wsp_ggml_scale_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + bool inplace) { + WSP_GGML_ASSERT(wsp_ggml_is_scalar(b)); + WSP_GGML_ASSERT(wsp_ggml_is_padded_1d(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_SCALE; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_scale( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_scale_impl(ctx, a, b, false); +} + +struct wsp_ggml_tensor * wsp_ggml_scale_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_scale_impl(ctx, a, b, true); +} + +// wsp_ggml_set + +struct wsp_ggml_tensor * wsp_ggml_set_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, + bool inplace) { + WSP_GGML_ASSERT(wsp_ggml_nelements(a) >= wsp_ggml_nelements(b)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + // make a view of the destination + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * c = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 5); + + (( int32_t * ) c->data)[0] = nb1; + (( int32_t * ) c->data)[1] = nb2; + (( int32_t * ) c->data)[2] = nb3; + (( int32_t * ) c->data)[3] = offset; + (( int32_t * ) c->data)[4] = inplace ? 1 : 0; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_SET; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + result->opt[0] = c; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_set( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return wsp_ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false); +} + +struct wsp_ggml_tensor * wsp_ggml_set_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return wsp_ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true); +} + +struct wsp_ggml_tensor * wsp_ggml_set_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t offset) { + return wsp_ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false); +} + +struct wsp_ggml_tensor * wsp_ggml_set_1d_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t offset) { + return wsp_ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true); +} + +struct wsp_ggml_tensor * wsp_ggml_set_2d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t offset) { + return wsp_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); +} + +struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t offset) { + return wsp_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true); +} + + +// wsp_ggml_cpy + +struct wsp_ggml_tensor * wsp_ggml_cpy_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + bool inplace) { + WSP_GGML_ASSERT(wsp_ggml_nelements(a) == wsp_ggml_nelements(b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + // make a view of the destination + struct wsp_ggml_tensor * result = wsp_ggml_view_tensor(ctx, b); + if (strlen(b->name) > 0) { + wsp_ggml_format_name(result, "%s (copy of %s)", b->name, a->name); + } else { + wsp_ggml_format_name(result, "%s (copy)", a->name); + } + + result->op = WSP_GGML_OP_CPY; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_cpy( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_cpy_impl(ctx, a, b, false); +} + +struct wsp_ggml_tensor * wsp_ggml_cpy_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_cpy_impl(ctx, a, b, true); +} + +// wsp_ggml_cont + +struct wsp_ggml_tensor * wsp_ggml_cont_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && a->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + wsp_ggml_format_name(result, "%s (cont)", a->name); + + result->op = WSP_GGML_OP_CONT; + result->grad = is_node ?
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_cont( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_cont_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_cont_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_cont_impl(ctx, a, true); +} + +// wsp_ggml_reshape + +struct wsp_ggml_tensor * wsp_ggml_reshape( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(b)); + WSP_GGML_ASSERT(wsp_ggml_nelements(a) == wsp_ggml_nelements(b)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + if (b->grad) { + // gradient propagation is not supported + //WSP_GGML_ASSERT(false); + } + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data); + wsp_ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = WSP_GGML_OP_RESHAPE; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_reshape_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a)); + WSP_GGML_ASSERT(wsp_ggml_nelements(a) == ne0); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[1] = { ne0 }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, 1, ne, a->data); + wsp_ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = WSP_GGML_OP_RESHAPE; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_reshape_2d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a)); + WSP_GGML_ASSERT(wsp_ggml_nelements(a) == ne0*ne1); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[2] = { ne0, ne1 }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data); + wsp_ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = WSP_GGML_OP_RESHAPE; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_reshape_3d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a)); + WSP_GGML_ASSERT(wsp_ggml_nelements(a) == ne0*ne1*ne2); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[3] = { ne0, ne1, ne2 }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data); + wsp_ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = WSP_GGML_OP_RESHAPE; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + + +struct wsp_ggml_tensor * wsp_ggml_reshape_4d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a)); + WSP_GGML_ASSERT(wsp_ggml_nelements(a) == ne0*ne1*ne2*ne3); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, 4, ne, a->data); + wsp_ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = WSP_GGML_OP_RESHAPE; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// wsp_ggml_view_1d + +struct wsp_ggml_tensor * wsp_ggml_view_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + size_t offset) { + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); + wsp_ggml_format_name(result, "%s (view)", a->name); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * offs = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 2); + wsp_ggml_set_name(offs, "offset"); + memcpy(offs->data, &offset, 2*sizeof(int32_t)); + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_VIEW; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + result->opt[0] = offs; + + return result; +} + +// wsp_ggml_view_2d + +struct wsp_ggml_tensor * wsp_ggml_view_2d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + size_t nb1, + size_t offset) { + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[WSP_GGML_MAX_DIMS] = { ne0, ne1, 1, 1 }; + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset); + wsp_ggml_format_name(result, "%s (view)", a->name); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * offs = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 2); + wsp_ggml_set_name(offs, "offset"); + memcpy(offs->data, &offset, 2*sizeof(int32_t)); + + wsp_ggml_scratch_load(ctx); + + result->nb[1] = nb1; + result->nb[2] = result->nb[1]*ne1; + result->nb[3] = result->nb[2]; + + result->op = WSP_GGML_OP_VIEW; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + result->opt[0] = offs; + + return result; +} + +// wsp_ggml_view_3d + +struct wsp_ggml_tensor * wsp_ggml_view_3d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + size_t nb1, + size_t nb2, + size_t offset) { + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[WSP_GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 }; + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset); + wsp_ggml_format_name(result, "%s (view)", a->name); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * offs = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 2); + wsp_ggml_set_name(offs, "offset"); + memcpy(offs->data, &offset, 2*sizeof(int32_t)); + + wsp_ggml_scratch_load(ctx); + + result->nb[1] = nb1; + result->nb[2] = nb2; + result->nb[3] = result->nb[2]*ne2; + + result->op = WSP_GGML_OP_VIEW; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + result->opt[0] = offs; + + return result; +} + +// wsp_ggml_view_4d + +struct wsp_ggml_tensor * wsp_ggml_view_4d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[WSP_GGML_MAX_DIMS] = { ne0, ne1, ne2, ne3 }; + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset); + wsp_ggml_format_name(result, "%s (view)", a->name); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * offs = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 2); + wsp_ggml_set_name(offs, "offset"); + memcpy(offs->data, &offset, 2*sizeof(int32_t)); + + wsp_ggml_scratch_load(ctx); + + result->nb[1] = nb1; + result->nb[2] = nb2; + result->nb[3] = nb3; + + result->op = WSP_GGML_OP_VIEW; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + result->opt[0] = offs; + + return result; +} + +// wsp_ggml_permute + +struct wsp_ggml_tensor * wsp_ggml_permute( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3) { + WSP_GGML_ASSERT(axis0 >= 0 && axis0 < WSP_GGML_MAX_DIMS); + WSP_GGML_ASSERT(axis1 >= 0 && axis1 < WSP_GGML_MAX_DIMS); + WSP_GGML_ASSERT(axis2 >= 0 && axis2 < WSP_GGML_MAX_DIMS); + WSP_GGML_ASSERT(axis3 >= 0 && axis3 < WSP_GGML_MAX_DIMS); + + WSP_GGML_ASSERT(axis0 != axis1); + WSP_GGML_ASSERT(axis0 != axis2); + WSP_GGML_ASSERT(axis0 != axis3); + WSP_GGML_ASSERT(axis1 != axis2); + WSP_GGML_ASSERT(axis1 != axis3); + WSP_GGML_ASSERT(axis2 != axis3); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = wsp_ggml_view_tensor(ctx, a); + wsp_ggml_format_name(result, "%s (permuted)", a->name); + + int ne[WSP_GGML_MAX_DIMS]; + int nb[WSP_GGML_MAX_DIMS]; + + ne[axis0] = a->ne[0]; + ne[axis1] = a->ne[1]; + ne[axis2] = a->ne[2]; + ne[axis3] = a->ne[3]; + + nb[axis0] = a->nb[0]; + nb[axis1] = a->nb[1]; + nb[axis2] = a->nb[2]; + nb[axis3] = a->nb[3]; + + result->ne[0] = ne[0]; + result->ne[1] = ne[1]; + result->ne[2] = ne[2]; + result->ne[3] = ne[3]; + + result->nb[0] = nb[0]; + result->nb[1] = nb[1]; + result->nb[2] = nb[2]; + result->nb[3] = nb[3]; + + result->op = WSP_GGML_OP_PERMUTE; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + if (is_node) { + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 4); + + ((int32_t *) b->data)[0] = axis0; + ((int32_t *) b->data)[1] = axis1; + ((int32_t *) b->data)[2] = axis2; + ((int32_t *) b->data)[3] = axis3; + + wsp_ggml_scratch_load(ctx); + + result->opt[0] = b; + } + + return result; +} + +// wsp_ggml_transpose + +struct wsp_ggml_tensor * wsp_ggml_transpose( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = wsp_ggml_view_tensor(ctx, a); + wsp_ggml_format_name(result, "%s (transposed)", a->name); + + result->ne[0] = a->ne[1]; + result->ne[1] = a->ne[0]; + + result->nb[0] = a->nb[1]; + result->nb[1] = a->nb[0]; + + result->op = WSP_GGML_OP_TRANSPOSE; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// wsp_ggml_get_rows + +struct wsp_ggml_tensor * wsp_ggml_get_rows( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + WSP_GGML_ASSERT(wsp_ggml_is_matrix(a) && wsp_ggml_is_vector(b) && b->type == WSP_GGML_TYPE_I32); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + // TODO: implement non F32 return + //struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, a->ne[0], b->ne[0]); + + result->op = WSP_GGML_OP_GET_ROWS; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// wsp_ggml_get_rows_back + +struct wsp_ggml_tensor * wsp_ggml_get_rows_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + struct wsp_ggml_tensor * c) { + WSP_GGML_ASSERT(wsp_ggml_is_matrix(a) && wsp_ggml_is_vector(b) && b->type == WSP_GGML_TYPE_I32); + WSP_GGML_ASSERT(wsp_ggml_is_matrix(c) && (a->ne[0] == c->ne[0])); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + // TODO: implement non F32 return + //struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, c->ne[0], c->ne[1]); + + result->op = WSP_GGML_OP_GET_ROWS_BACK; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + result->opt[0] = c; + + return result; +} + +// wsp_ggml_diag + +struct wsp_ggml_tensor * wsp_ggml_diag( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + WSP_GGML_ASSERT(a->ne[1] == 1); + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, MAX(a->n_dims, 2), ne); + + result->op = WSP_GGML_OP_DIAG; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + + +// wsp_ggml_diag_mask_inf + +struct wsp_ggml_tensor * wsp_ggml_diag_mask_inf_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past, + bool inplace) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 2); + + ((int32_t *) b->data)[0] = n_past; + ((int32_t *) b->data)[1] = inplace ? 1 : 0; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_DIAG_MASK_INF; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_diag_mask_inf( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past) { + return wsp_ggml_diag_mask_inf_impl(ctx, a, n_past, false); +} + + +struct wsp_ggml_tensor * wsp_ggml_diag_mask_inf_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past) { + return wsp_ggml_diag_mask_inf_impl(ctx, a, n_past, true); +} + +// wsp_ggml_diag_mask_zero + +struct wsp_ggml_tensor * wsp_ggml_diag_mask_zero_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past, + bool inplace) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 2); + wsp_ggml_set_name(b, "n_past, inplace"); + + ((int32_t *) b->data)[0] = n_past; + ((int32_t *) b->data)[1] = inplace ? 1 : 0; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_DIAG_MASK_ZERO; + result->grad = is_node ? 
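+    // [editor's sketch, not part of the upstream patch] The diag_mask ops implement
+    // causal masking: positions past the diagonal shifted by n_past are overwritten
+    // with -INF (for pre-softmax scores) or with 0. A typical attention pattern, shown
+    // for illustration only with invented names:
+    //
+    //   // scores: attention logits of shape [n_kv, n_tokens, n_head]
+    //   scores = wsp_ggml_diag_mask_inf_inplace(ctx, scores, n_past);
+    //   scores = wsp_ggml_soft_max_inplace(ctx, scores);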
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_diag_mask_zero( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past) { + return wsp_ggml_diag_mask_zero_impl(ctx, a, n_past, false); +} + +struct wsp_ggml_tensor * wsp_ggml_diag_mask_zero_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past) { + return wsp_ggml_diag_mask_zero_impl(ctx, a, n_past, true); +} + +// wsp_ggml_soft_max + +struct wsp_ggml_tensor * wsp_ggml_soft_max_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_SOFT_MAX; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_soft_max( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_soft_max_impl(ctx, a, false); +} + +struct wsp_ggml_tensor * wsp_ggml_soft_max_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_soft_max_impl(ctx, a, true); +} + + +// wsp_ggml_soft_max_back + +struct wsp_ggml_tensor * wsp_ggml_soft_max_back_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + bool inplace) { + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; // TODO : implement backward pass + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_SOFT_MAX_BACK; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_soft_max_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_soft_max_back_impl(ctx, a, b, false); +} + +struct wsp_ggml_tensor * wsp_ggml_soft_max_back_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + return wsp_ggml_soft_max_back_impl(ctx, a, b, true); +} + +// wsp_ggml_rope + +struct wsp_ggml_tensor * wsp_ggml_rope_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx, + bool inplace) { + WSP_GGML_ASSERT(n_past >= 0); + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 4); + + ((int32_t *) b->data)[0] = n_past; + ((int32_t *) b->data)[1] = n_dims; + ((int32_t *) b->data)[2] = mode; + ((int32_t *) b->data)[3] = n_ctx; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_ROPE; + result->grad = is_node ? 
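+    // [editor's sketch, not part of the upstream patch] RoPE's integer parameters are
+    // not stored on the op itself: (n_past, n_dims, mode, n_ctx) are packed into the
+    // 4-element I32 tensor `b` above and travel through the graph as src1, where the
+    // compute kernel reads them back. Hypothetical call with invented names:
+    //
+    //   Qcur = wsp_ggml_rope_inplace(ctx, Qcur, n_past, n_rot, /*mode=*/0, /*n_ctx=*/0);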
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_rope( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx) { + return wsp_ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false); +} + +struct wsp_ggml_tensor * wsp_ggml_rope_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx) { + return wsp_ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true); +} + +// wsp_ggml_rope_back + +struct wsp_ggml_tensor * wsp_ggml_rope_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past, + int n_dims, + int mode) { + WSP_GGML_ASSERT(n_past >= 0); + WSP_GGML_ASSERT((mode & 4) == 0 && "wsp_ggml_rope_back() for ChatGLM not implemented yet"); + + bool is_node = false; + + if (a->grad) { + is_node = false; // TODO: implement backward + } + + struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 3); + wsp_ggml_set_name(b, "n_past, n_dims, mode"); + + ((int32_t *) b->data)[0] = n_past; + ((int32_t *) b->data)[1] = n_dims; + ((int32_t *) b->data)[2] = mode; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_ROPE_BACK; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// wsp_ggml_alibi + +struct wsp_ggml_tensor * wsp_ggml_alibi( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past, + int n_head, + float bias_max) { + WSP_GGML_ASSERT(n_past >= 0); + bool is_node = false; + + if (a->grad) { + WSP_GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + struct wsp_ggml_tensor * result = wsp_ggml_view_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 3); + + ((int32_t *) b->data)[0] = n_past; + ((int32_t *) b->data)[1] = n_head; + WSP_GGML_ASSERT(sizeof(float) == sizeof(int32_t)); + (((float *) b->data)[2]) = bias_max; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_ALIBI; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// wsp_ggml_clamp + +struct wsp_ggml_tensor * wsp_ggml_clamp( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + float min, + float max) { + bool is_node = false; + + if (a->grad) { + WSP_GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + struct wsp_ggml_tensor * result = wsp_ggml_view_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 2); + + ((float *) b->data)[0] = min; + ((float *) b->data)[1] = max; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_CLAMP; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// wsp_ggml_conv_1d + +static int64_t wsp_ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; +} + +WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s0, + int p0, + int d0) { + WSP_GGML_ASSERT(wsp_ggml_is_matrix(b)); + WSP_GGML_ASSERT(a->ne[1] == b->ne[1]); + bool is_node = false; + + if (a->grad || b->grad) { + WSP_GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), + a->ne[2], 1, 1, + }; + struct wsp_ggml_tensor* result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 2, ne); + + wsp_ggml_scratch_save(ctx); + struct wsp_ggml_tensor* c = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 3); + ((int32_t*)c->data)[0] = s0; + ((int32_t*)c->data)[1] = p0; + ((int32_t*)c->data)[2] = d0; + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_CONV_1D; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + result->opt[0] = c; + + return result; +} + +// wsp_ggml_conv_2d + +struct wsp_ggml_tensor* wsp_ggml_conv_2d( + struct wsp_ggml_context* ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + + WSP_GGML_ASSERT(b->ne[3] == 1); + WSP_GGML_ASSERT(a->ne[2] == b->ne[2]); + bool is_node = false; + + if (a->grad || b->grad) { + WSP_GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), + wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1), + a->ne[3], 1, + }; + struct wsp_ggml_tensor* result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); + + wsp_ggml_scratch_save(ctx); + struct wsp_ggml_tensor* c = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 6); + ((int32_t*)c->data)[0] = s0; + ((int32_t*)c->data)[1] = s1; + ((int32_t*)c->data)[2] = p0; + ((int32_t*)c->data)[3] = p1; + ((int32_t*)c->data)[4] = d0; + ((int32_t*)c->data)[5] = d1; + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_CONV_2D; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + result->opt[0] = c; + + return result; + +} + +// wsp_ggml_conv_1d_ph + +struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s, + int d) { + return wsp_ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); +} + +// wsp_ggml_flash_attn + +struct wsp_ggml_tensor * wsp_ggml_flash_attn( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * q, + struct wsp_ggml_tensor * k, + struct wsp_ggml_tensor * v, + bool masked) { + WSP_GGML_ASSERT(wsp_ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + //struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, q); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, q->ne); + + result->op = WSP_GGML_OP_FLASH_ATTN; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = q; + result->src1 = k; + result->opt[0] = v; + result->opt[1] = wsp_ggml_new_i32(ctx, masked ? 
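+    // [editor's sketch, not part of the upstream patch] wsp_ggml_flash_attn appears to
+    // fuse the QK^T matmul, optional causal mask, softmax and the V matmul into one op;
+    // only the graph node is built here, the work happens in the compute pass, and the
+    // `masked` flag rides along as an I32 scalar in opt[1]. Hypothetical usage
+    // (tensor names invented; q with ne[0] = head dim D and ne[1] = n_tokens, as in
+    // the asserts of wsp_ggml_flash_attn_back below):
+    //
+    //   struct wsp_ggml_tensor * kqv = wsp_ggml_flash_attn(ctx, q, k, v, /*masked=*/true);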
1 : 0); + + return result; +} + +// wsp_ggml_flash_ff + +struct wsp_ggml_tensor * wsp_ggml_flash_ff( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b0, + struct wsp_ggml_tensor * b1, + struct wsp_ggml_tensor * c0, + struct wsp_ggml_tensor * c1) { + WSP_GGML_ASSERT(wsp_ggml_can_mul_mat(b0, a)); + // TODO: more checks + + bool is_node = false; + + if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { + is_node = true; + } + + //struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, a->ne); + + result->op = WSP_GGML_OP_FLASH_FF; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b0; + result->opt[0] = b1; + result->opt[1] = c0; + result->opt[2] = c1; + + return result; +} + +// wsp_ggml_flash_attn_back + +struct wsp_ggml_tensor * wsp_ggml_flash_attn_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * q, + struct wsp_ggml_tensor * k, + struct wsp_ggml_tensor * v, + struct wsp_ggml_tensor * d, + bool masked) { + WSP_GGML_ASSERT(wsp_ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + // d shape [D,N,ne2,ne3] + // q shape [D,N,ne2,ne3] + // k shape [D,M,ne2,ne3] + // v shape [M,D,ne2,ne3] + + const int64_t D = q->ne[0]; + const int64_t N = q->ne[1]; + const int64_t M = k->ne[1]; + const int64_t ne2 = q->ne[2]; + const int64_t ne3 = q->ne[3]; + + WSP_GGML_ASSERT(k->ne[0] == D); + WSP_GGML_ASSERT(v->ne[0] == M); + WSP_GGML_ASSERT(v->ne[1] == D); + WSP_GGML_ASSERT(d->ne[0] == D); + WSP_GGML_ASSERT(d->ne[1] == N); + WSP_GGML_ASSERT(k->ne[2] == ne2); + WSP_GGML_ASSERT(k->ne[3] == ne3); + WSP_GGML_ASSERT(v->ne[2] == ne2); + WSP_GGML_ASSERT(v->ne[3] == ne3); + WSP_GGML_ASSERT(d->ne[2] == ne2); + WSP_GGML_ASSERT(d->ne[3] == ne3); + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + // when using this operation (in backwards pass) these grads are set. + // we don't want to create (big) grad of our result, so is_node is false. + is_node = false; + } + + // store gradients of q, k and v as continuous tensors concatenated in result. + // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3] + // gradq->data = result->data + // gradk->data = result->data + nb0*D*N*ne2*ne3 + // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3 + // note: v and gradv are actually transposed, i.e. v->ne[0] != D. + int64_t ne[4] = {D,M+N+M,ne2,ne3}; + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); + + result->op = WSP_GGML_OP_FLASH_ATTN_BACK; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = q; + result->src1 = k; + result->opt[0] = v; + result->opt[1] = d; + result->opt[2] = wsp_ggml_new_i32(ctx, masked ? 
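+    // [editor's note, not part of the upstream patch] The single result tensor of shape
+    // [D, N+M+M, ne2, ne3] concatenates all three gradients, per the layout comment
+    // above: grad(q) occupies the first nb0*D*N*ne2*ne3 bytes, grad(k) the next
+    // nb0*D*M*ne2*ne3, and grad(v) the rest (stored transposed, like v itself). That
+    // is also why is_node stays false here: this op *is* the backward pass.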
1 : 0); + + return result; +} + +// wsp_ggml_win_part + +struct wsp_ggml_tensor * wsp_ggml_win_part( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int w) { + WSP_GGML_ASSERT(a->ne[3] == 1); + WSP_GGML_ASSERT(a->type == WSP_GGML_TYPE_F32); + + bool is_node = false; + + if (a->grad) { + WSP_GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // padding + const int px = (w - a->ne[1]%w)%w; + const int py = (w - a->ne[2]%w)%w; + + const int npx = (px + a->ne[1])/w; + const int npy = (py + a->ne[2])/w; + const int np = npx*npy; + + const int64_t ne[4] = { a->ne[0], w, w, np, }; + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 3); + + ((int32_t *) b->data)[0] = npx; + ((int32_t *) b->data)[1] = npy; + ((int32_t *) b->data)[2] = w; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_WIN_PART; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + result->opt[0] = b; + + return result; +} + +// wsp_ggml_win_unpart + +struct wsp_ggml_tensor * wsp_ggml_win_unpart( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int w0, + int h0, + int w) { + WSP_GGML_ASSERT(a->type == WSP_GGML_TYPE_F32); + + bool is_node = false; + + if (a->grad) { + WSP_GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 3, ne); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, 1); + + ((int32_t *) b->data)[0] = w; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_WIN_UNPART; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + result->opt[0] = b; + + return result; +} + +// wsp_ggml_map_unary + +struct wsp_ggml_tensor * wsp_ggml_map_unary_impl_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + const wsp_ggml_unary_op_f32_t fun, + bool inplace) { + bool is_node = false; + + if (!inplace && a->grad) { + is_node = true; + } + + struct wsp_ggml_tensor *result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * addr_tensor = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); + *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_MAP_UNARY; + result->grad = is_node ? 
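+    // [editor's sketch, not part of the upstream patch] The map_* ops smuggle a raw C
+    // function pointer through the graph: it is written into an I32 tensor sized
+    // sizeof(void *) / sizeof(int32_t) elements and read back in the compute pass.
+    // Hypothetical usage, identifiers invented (the callback signature follows the
+    // wsp_ggml_unary_op_f32_t typedef used above):
+    //
+    //   static void my_square(const int n, float * dst, const float * src) {
+    //       for (int i = 0; i < n; ++i) dst[i] = src[i] * src[i];
+    //   }
+    //   ...
+    //   struct wsp_ggml_tensor * y = wsp_ggml_map_unary_f32(ctx, x, my_square);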
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->opt[0] = addr_tensor; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_map_unary_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + const wsp_ggml_unary_op_f32_t fun) { + return wsp_ggml_map_unary_impl_f32(ctx, a, fun, false); +} + +struct wsp_ggml_tensor * wsp_ggml_map_unary_inplace_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + const wsp_ggml_unary_op_f32_t fun) { + return wsp_ggml_map_unary_impl_f32(ctx, a, fun, true); +} + +// wsp_ggml_map_binary + +struct wsp_ggml_tensor * wsp_ggml_map_binary_impl_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + const wsp_ggml_binary_op_f32_t fun, + bool inplace) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor *result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * addr_tensor = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); + *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_MAP_BINARY; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + result->opt[0] = addr_tensor; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_map_binary_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + const wsp_ggml_binary_op_f32_t fun) { + return wsp_ggml_map_binary_impl_f32(ctx, a, b, fun, false); +} + +struct wsp_ggml_tensor * wsp_ggml_map_binary_inplace_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + const wsp_ggml_binary_op_f32_t fun) { + return wsp_ggml_map_binary_impl_f32(ctx, a, b, fun, true); +} + +// wsp_ggml_map_custom1 + +struct wsp_ggml_tensor * wsp_ggml_map_custom1_impl_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + const wsp_ggml_custom1_op_f32_t fun, + bool inplace) { + bool is_node = false; + + if (!inplace && a->grad) { + is_node = true; + } + + struct wsp_ggml_tensor *result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * addr_tensor = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); + *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_MAP_CUSTOM1; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->opt[0] = addr_tensor; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_map_custom1_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + const wsp_ggml_custom1_op_f32_t fun) { + return wsp_ggml_map_custom1_impl_f32(ctx, a, fun, false); +} + +struct wsp_ggml_tensor * wsp_ggml_map_custom1_inplace_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + const wsp_ggml_custom1_op_f32_t fun) { + return wsp_ggml_map_custom1_impl_f32(ctx, a, fun, true); +} + +// wsp_ggml_map_custom2 + +struct wsp_ggml_tensor * wsp_ggml_map_custom2_impl_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + const wsp_ggml_custom2_op_f32_t fun, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor *result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * addr_tensor = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); + *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_MAP_CUSTOM2; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + result->opt[0] = addr_tensor; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_map_custom2_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + const wsp_ggml_custom2_op_f32_t fun) { + return wsp_ggml_map_custom2_impl_f32(ctx, a, b, fun, false); +} + +struct wsp_ggml_tensor * wsp_ggml_map_custom2_inplace_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + const wsp_ggml_custom2_op_f32_t fun) { + return wsp_ggml_map_custom2_impl_f32(ctx, a, b, fun, true); +} + +// wsp_ggml_map_custom3 + +struct wsp_ggml_tensor * wsp_ggml_map_custom3_impl_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + struct wsp_ggml_tensor * c, + const wsp_ggml_custom3_op_f32_t fun, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad || b->grad || c->grad)) { + is_node = true; + } + + struct wsp_ggml_tensor *result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + + wsp_ggml_scratch_save(ctx); + + struct wsp_ggml_tensor * addr_tensor = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); + *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; + + wsp_ggml_scratch_load(ctx); + + result->op = WSP_GGML_OP_MAP_CUSTOM3; + result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + result->opt[0] = addr_tensor; + result->opt[1] = c; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_map_custom3_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + struct wsp_ggml_tensor * c, + const wsp_ggml_custom3_op_f32_t fun) { + return wsp_ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false); +} + +struct wsp_ggml_tensor * wsp_ggml_map_custom3_inplace_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + struct wsp_ggml_tensor * c, + const wsp_ggml_custom3_op_f32_t fun) { + return wsp_ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true); +} + +// wsp_ggml_cross_entropy_loss + +struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b)); + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = WSP_GGML_OP_CROSS_ENTROPY_LOSS; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// wsp_ggml_cross_entropy_loss_back + +struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + struct wsp_ggml_tensor * c) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b)); + WSP_GGML_ASSERT(wsp_ggml_is_scalar(c)); + + struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a); + + result->op = WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK; + result->grad = NULL; + result->src0 = a; + result->src1 = b; + result->opt[0] = c; + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void wsp_ggml_set_param( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * tensor) { + tensor->is_param = true; + + WSP_GGML_ASSERT(tensor->grad == NULL); + tensor->grad = wsp_ggml_dup_tensor(ctx, tensor); +} + +// wsp_ggml_compute_forward_dup + +static void wsp_ggml_compute_forward_dup_same_cont( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0)); + WSP_GGML_ASSERT(src0->type == dst->type); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const size_t nb00 = src0->nb[0]; + const size_t nb0 = dst->nb[0]; + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by elements + const int ne = wsp_ggml_nelements(dst); + const int dr = (ne + nth - 1) / nth; + const int ie0 = dr * ith; + const int ie1 = MIN(ie0 + dr, ne); + + if (ie0 < ie1) { + memcpy( + ((char *) dst->data + ie0*nb0), + ((char *) src0->data + ie0*nb00), + (ie1 - ie0) * WSP_GGML_TYPE_SIZE[src0->type]); + } + +} +static void wsp_ggml_compute_forward_dup_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + 
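+    // [editor's note, not part of the upstream patch] Each compute_forward_* kernel is
+    // entered by every worker thread: params->ith is this thread's index, params->nth
+    // the thread count, and the kernels return early for the INIT / FINALIZE phases.
+    // Rows are then split with the usual ceiling divide, as in the code below:
+    //
+    //   dr  = (nr + nth - 1) / nth;   // rows per thread
+    //   ir0 = dr * ith;               // first row owned by this thread
+    //   ir1 = MIN(ir0 + dr, nr);      // one past its last row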
WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + if (wsp_ggml_is_contiguous(src0) && wsp_ggml_is_contiguous(dst) && src0->type == dst->type) { + wsp_ggml_compute_forward_dup_same_cont(params, src0, dst); + return; + } + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == WSP_GGML_TYPE_SIZE[src0->type] && nb0 == WSP_GGML_TYPE_SIZE[dst->type]) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy + + if (wsp_ggml_is_contiguous(dst)) { + if (nb00 == sizeof(wsp_ggml_fp16_t)) { + if (dst->type == WSP_GGML_TYPE_F16) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (dst->type == WSP_GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = WSP_GGML_FP16_TO_FP32(src0_ptr[i00]); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (wsp_ggml_is_quantized(dst->type)) { + quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q; + float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + size_t id = 0; + size_t rs = nb0 * (ne00 / WSP_GGML_BLCK_SIZE[dst->type]); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = WSP_GGML_FP16_TO_FP32(src0_ptr[i00]); + } + + quantize_row_q(src0_f32, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + WSP_GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == WSP_GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = WSP_GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + id += ne00 
* (ne01 - ir1); + } + } + } else if (dst->type == WSP_GGML_TYPE_F16) { + size_t id = 0; + wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + WSP_GGML_ASSERT(false); // TODO: implement + } + } + return; + } + + // dst counters + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == WSP_GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(wsp_ggml_fp16_t)); + + if (++i10 == ne00) { + i10 = 0; + if (++i11 == ne01) { + i11 = 0; + if (++i12 == ne02) { + i12 = 0; + if (++i13 == ne03) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == WSP_GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(float *) dst_ptr = WSP_GGML_FP16_TO_FP32(*(const wsp_ggml_fp16_t *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + WSP_GGML_ASSERT(false); // TODO: implement + } +} + +static void wsp_ggml_compute_forward_dup_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + if (wsp_ggml_is_contiguous(src0) && wsp_ggml_is_contiguous(dst) && src0->type == dst->type) { + wsp_ggml_compute_forward_dup_same_cont(params, src0, dst); + return; + } + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row 
range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == WSP_GGML_TYPE_SIZE[src0->type] && nb0 == WSP_GGML_TYPE_SIZE[dst->type]) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + if (wsp_ggml_is_contiguous(dst)) { + // TODO: simplify + if (nb00 == sizeof(float)) { + if (dst->type == WSP_GGML_TYPE_F32) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (dst->type == WSP_GGML_TYPE_F16) { + size_t id = 0; + wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = WSP_GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (wsp_ggml_is_quantized(dst->type)) { + quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q; + + size_t id = 0; + size_t rs = nb0 * (ne00 / WSP_GGML_BLCK_SIZE[dst->type]); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + quantize_row_q(src0_ptr, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + WSP_GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == WSP_GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == WSP_GGML_TYPE_F16) { + size_t id = 0; + wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = WSP_GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + WSP_GGML_ASSERT(false); // TODO: implement + } + } + + return; + } + + // dst counters + + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == WSP_GGML_TYPE_F32) { + for 
(int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(float)); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == WSP_GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(wsp_ggml_fp16_t *) dst_ptr = WSP_GGML_FP32_TO_FP16(*(const float *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + WSP_GGML_ASSERT(false); // TODO: implement + } +} + +static void wsp_ggml_compute_forward_dup( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + if (wsp_ggml_is_contiguous(src0) && wsp_ggml_is_contiguous(dst) && src0->type == dst->type) { + wsp_ggml_compute_forward_dup_same_cont(params, src0, dst); + return; + } + switch (src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_dup_f16(params, src0, dst); + } break; + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_dup_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_add + +static void wsp_ggml_compute_forward_add_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(src0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + WSP_GGML_ASSERT( nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + 
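+            // [editor's note, not part of the upstream patch] ir runs over the rows of
+            // the flattened [ne1, ne2, ne3] outer index space; the three lines below
+            // unpack it back into coordinates, like splitting a linear index into digits:
+            //   i3 = ir / (ne2*ne1)                     (slowest-varying)
+            //   i2 = (ir - i3*ne2*ne1) / ne1
+            //   i1 =  ir - i3*ne2*ne1 - i2*ne1          (fastest of the three)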
const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + +#ifdef WSP_GGML_USE_ACCELERATE + vDSP_vadd( + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, + ne0); +#else + wsp_ggml_vec_add_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); +#endif + // } + // } + } + } else { + // src1 is not contiguous + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i0 = 0; i0 < ne0; i0++) { + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + + dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; + } + } + } +} + +static void wsp_ggml_compute_forward_add_f16_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(src0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F16); + + WSP_GGML_ASSERT( nb0 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + } + } + } + else { + // src1 is not contiguous + WSP_GGML_ASSERT(false); + } +} + +static void wsp_ggml_compute_forward_add_f16_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(src0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + 
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F16); + + WSP_GGML_ASSERT( nb0 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(wsp_ggml_fp16_t)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + wsp_ggml_fp16_t * src1_ptr = (wsp_ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(src0_ptr[i]) + WSP_GGML_FP16_TO_FP32(src1_ptr[i])); + } + } + } + else { + // src1 is not contiguous + WSP_GGML_ASSERT(false); + } +} + +static void wsp_ggml_compute_forward_add_q_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int nr = wsp_ggml_nrows(src0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const enum wsp_ggml_type type = src0->type; + dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; + quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q; + + // we don't support permuted src0 or src1 + WSP_GGML_ASSERT(nb00 == WSP_GGML_TYPE_SIZE[type]); + WSP_GGML_ASSERT(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + WSP_GGML_ASSERT(nb0 <= nb1); + WSP_GGML_ASSERT(nb1 <= nb2); + WSP_GGML_ASSERT(nb2 <= nb3); + + WSP_GGML_ASSERT(wsp_ggml_is_quantized(src0->type)); + WSP_GGML_ASSERT(dst->type == src0->type); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + // src1 and dst are same shape as src0 => same indices + const int i13 = i03; + const int i12 = i02; + const int i11 = i01; + + const int i3 = i03; + const int i2 = i02; + const int i1 = i01; + + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)); + void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + // unquantize row from src0 to temp buffer + dequantize_row_q(src0_row, wdata, ne00); + // add src1 + wsp_ggml_vec_acc_f32(ne00, wdata, src1_row); + // quantize row to dst + quantize_row_q(wdata, dst_row, ne00); + } +} + +static void wsp_ggml_compute_forward_add( + const struct 
wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_add_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F16: + { + if (src1->type == WSP_GGML_TYPE_F16) { + wsp_ggml_compute_forward_add_f16_f16(params, src0, src1, dst); + } + else if (src1->type == WSP_GGML_TYPE_F32) { + wsp_ggml_compute_forward_add_f16_f32(params, src0, src1, dst); + } + else { + WSP_GGML_ASSERT(false); + } + } break; + case WSP_GGML_TYPE_Q4_0: + case WSP_GGML_TYPE_Q4_1: + case WSP_GGML_TYPE_Q5_0: + case WSP_GGML_TYPE_Q5_1: + case WSP_GGML_TYPE_Q8_0: + case WSP_GGML_TYPE_Q2_K: + case WSP_GGML_TYPE_Q3_K: + case WSP_GGML_TYPE_Q4_K: + case WSP_GGML_TYPE_Q5_K: + case WSP_GGML_TYPE_Q6_K: + { + wsp_ggml_compute_forward_add_q_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_add1 + +static void wsp_ggml_compute_forward_add1_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + WSP_GGML_ASSERT(wsp_ggml_is_scalar(src1)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(src0); + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + WSP_GGML_ASSERT( nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + +#ifdef WSP_GGML_USE_ACCELERATE + UNUSED(wsp_ggml_vec_add1_f32); + + vDSP_vadd( + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, + (float *) ((char *) src1->data), 0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, + ne0); +#else + wsp_ggml_vec_add1_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + *(float *) src1->data); +#endif + } +} + +static void wsp_ggml_compute_forward_add1_f16_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + WSP_GGML_ASSERT(wsp_ggml_is_scalar(src1)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(src0); + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F16); + + WSP_GGML_ASSERT( nb0 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir 
< ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void wsp_ggml_compute_forward_add1_f16_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + WSP_GGML_ASSERT(wsp_ggml_is_scalar(src1)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // scalar to add + const float v = WSP_GGML_FP16_TO_FP32(*(wsp_ggml_fp16_t *) src1->data); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(src0); + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F16); + + WSP_GGML_ASSERT( nb0 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void wsp_ggml_compute_forward_add1_q_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + WSP_GGML_ASSERT(wsp_ggml_is_scalar(src1)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(src0); + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + const enum wsp_ggml_type type = src0->type; + dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; + quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q; + + // we don't support permuted src0 + WSP_GGML_ASSERT(nb00 == WSP_GGML_TYPE_SIZE[type]); + + // dst cannot be transposed or permuted + WSP_GGML_ASSERT(nb0 <= nb1); + WSP_GGML_ASSERT(nb1 <= nb2); + WSP_GGML_ASSERT(nb2 <= nb3); + + WSP_GGML_ASSERT(wsp_ggml_is_quantized(src0->type)); + WSP_GGML_ASSERT(dst->type == src0->type); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) 
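+    // [editor's note, not part of the upstream patch] Quantized rows cannot be updated
+    // in place, so each thread takes its own float scratch slice of params->wdata
+    // (offset by ne0 + CACHE_LINE_SIZE_F32 per thread, presumably to avoid false
+    // sharing) and the loop below does dequantize row -> add the scalar -> requantize
+    // the row back into dst.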
* ith; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03)); + void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 )); + + assert(ne0 % 32 == 0); + + // unquantize row from src0 to temp buffer + dequantize_row_q(src0_row, wdata, ne0); + // add src1 + wsp_ggml_vec_acc1_f32(ne0, wdata, v); + // quantize row to dst + quantize_row_q(wdata, dst_row, ne0); + } +} + +static void wsp_ggml_compute_forward_add1( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_add1_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F16: + { + if (src1->type == WSP_GGML_TYPE_F16) { + wsp_ggml_compute_forward_add1_f16_f16(params, src0, src1, dst); + } + else if (src1->type == WSP_GGML_TYPE_F32) { + wsp_ggml_compute_forward_add1_f16_f32(params, src0, src1, dst); + } + else { + WSP_GGML_ASSERT(false); + } + } break; + case WSP_GGML_TYPE_Q4_0: + case WSP_GGML_TYPE_Q4_1: + case WSP_GGML_TYPE_Q5_0: + case WSP_GGML_TYPE_Q5_1: + case WSP_GGML_TYPE_Q8_0: + case WSP_GGML_TYPE_Q8_1: + case WSP_GGML_TYPE_Q2_K: + case WSP_GGML_TYPE_Q3_K: + case WSP_GGML_TYPE_Q4_K: + case WSP_GGML_TYPE_Q5_K: + case WSP_GGML_TYPE_Q6_K: + { + wsp_ggml_compute_forward_add1_q_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + + +// wsp_ggml_compute_forward_acc + +static void wsp_ggml_compute_forward_acc_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0)); + + WSP_GGML_ASSERT(opt0->type == WSP_GGML_TYPE_I32); + WSP_GGML_ASSERT(wsp_ggml_nelements(opt0) == 5); + + // view src0 and dst with these strides and data offset inbytes during acc + // nb0 is implicitely element_size because src0 and dst are contiguous + size_t nb1 = ((int32_t *) opt0->data)[0]; + size_t nb2 = ((int32_t *) opt0->data)[1]; + size_t nb3 = ((int32_t *) opt0->data)[2]; + size_t offset = ((int32_t *) opt0->data)[3]; + bool inplace = (bool) ((int32_t *) opt0->data)[4]; + + if (!inplace && (params->type == WSP_GGML_TASK_INIT)) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + wsp_ggml_nbytes(dst)); + } + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(src1); + const int nc = src1->ne[0]; + + WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); + + // src0 and dst as viewed during acc + const size_t nb0 = wsp_ggml_element_size(src0); + + const size_t nb00 = nb0; + const size_t nb01 = nb1; + const size_t nb02 = nb2; + const size_t nb03 = nb3; + + WSP_GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 
0 : ne13-1)*nb3 < wsp_ggml_nbytes(dst)); + WSP_GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < wsp_ggml_nbytes(src0)); + + WSP_GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are viewed with shape of src1 and offset + // => same indices + const int i3 = ir/(ne12*ne11); + const int i2 = (ir - i3*ne12*ne11)/ne11; + const int i1 = (ir - i3*ne12*ne11 - i2*ne11); + +#ifdef WSP_GGML_USE_ACCELERATE + vDSP_vadd( + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1, + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc); +#else + wsp_ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); +#endif + } +} + +static void wsp_ggml_compute_forward_acc( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_acc_f32(params, src0, src1, opt0, dst); + } break; + case WSP_GGML_TYPE_F16: + case WSP_GGML_TYPE_Q4_0: + case WSP_GGML_TYPE_Q4_1: + case WSP_GGML_TYPE_Q5_0: + case WSP_GGML_TYPE_Q5_1: + case WSP_GGML_TYPE_Q8_0: + case WSP_GGML_TYPE_Q8_1: + case WSP_GGML_TYPE_Q2_K: + case WSP_GGML_TYPE_Q3_K: + case WSP_GGML_TYPE_Q4_K: + case WSP_GGML_TYPE_Q5_K: + case WSP_GGML_TYPE_Q6_K: + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_sub + +static void wsp_ggml_compute_forward_sub_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int nr = wsp_ggml_nrows(src0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + WSP_GGML_ASSERT( nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + for (int ir = 0; ir < nr; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + +#ifdef WSP_GGML_USE_ACCELERATE + vDSP_vsub( + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, + ne0); +#else + wsp_ggml_vec_sub_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); +#endif + // } + // } + } + } else { + // src1 is not contiguous + for (int ir = 0; ir < nr; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir 
- i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i0 = 0; i0 < ne0; i0++) { + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + + dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; + } + } + } +} + +static void wsp_ggml_compute_forward_sub( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_sub_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_mul + +static void wsp_ggml_compute_forward_mul_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(src1, src0) && wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + const int ith = params->ith; + const int nth = params->nth; + +#ifdef WSP_GGML_USE_CLBLAST + if (src1->backend == WSP_GGML_BACKEND_GPU) { + if (ith == 0) { + wsp_ggml_cl_mul(src0, src1, dst); + } + return; + } +#endif + + const int64_t nr = wsp_ggml_nrows(src0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + WSP_GGML_ASSERT( nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb00 == sizeof(float)); + WSP_GGML_ASSERT(ne00 == ne10); + + if (nb10 == sizeof(float)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + +#ifdef WSP_GGML_USE_ACCELERATE + UNUSED(wsp_ggml_vec_mul_f32); + + vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00); +#else + wsp_ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr); +#endif + // } + // } + } + } else { + // src1 is not contiguous + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne00; i0++) { + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); + + dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); + } + } + } +} + +static void wsp_ggml_compute_forward_mul( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case 
WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_mul_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_div + +static void wsp_ggml_compute_forward_div_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int nr = wsp_ggml_nrows(src0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + WSP_GGML_ASSERT( nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + for (int ir = 0; ir < nr; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + +#ifdef WSP_GGML_USE_ACCELERATE + vDSP_vdiv( + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, + ne0); +#else + wsp_ggml_vec_div_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); +#endif + // } + // } + } + } else { + // src1 is not contiguous + for (int ir = 0; ir < nr; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i0 = 0; i0 < ne0; i0++) { + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + + dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr); + } + } + } +} + +static void wsp_ggml_compute_forward_div( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_div_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_sqr + +static void wsp_ggml_compute_forward_sqr_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + wsp_ggml_vec_sqr_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void wsp_ggml_compute_forward_sqr( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_sqr_f32(params, src0, dst); + } break; + default: + { + 
WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_sqrt + +static void wsp_ggml_compute_forward_sqrt_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + wsp_ggml_vec_sqrt_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void wsp_ggml_compute_forward_sqrt( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_sqrt_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + + +// wsp_ggml_compute_forward_log + +static void wsp_ggml_compute_forward_log_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(params->ith == 0); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + WSP_GGML_ASSERT( dst->nb[0] == sizeof(float)); + WSP_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + wsp_ggml_vec_log_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void wsp_ggml_compute_forward_log( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_log_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_sum + +static void wsp_ggml_compute_forward_sum_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_is_scalar(dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + assert(wsp_ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(float)); + + WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); + + wsp_ggml_float sum = 0; + wsp_ggml_float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + wsp_ggml_vec_sum_ggf(ne00, + &row_sum, + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + sum += row_sum; + } + } + } + ((float *) dst->data)[0] = sum; +} + +static void wsp_ggml_compute_forward_sum( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_sum_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_sum_rows + +static void wsp_ggml_compute_forward_sum_rows_f32( 
+ const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(params->ith == 0); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + WSP_GGML_ASSERT(src0->nb[0] == sizeof(float)); + WSP_GGML_ASSERT(dst->nb[0] == sizeof(float)); + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + WSP_GGML_ASSERT(ne0 == 1); + WSP_GGML_ASSERT(ne1 == ne01); + WSP_GGML_ASSERT(ne2 == ne02); + WSP_GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + float* src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + float* dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + float row_sum = 0; + wsp_ggml_vec_sum_f32(ne00, &row_sum, src_row); + dst_row[0] = row_sum; + } + } + } +} + +static void wsp_ggml_compute_forward_sum_rows( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_sum_rows_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_mean + +static void wsp_ggml_compute_forward_mean_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + assert(ne0 == 1); + assert(ne1 == ne01); + assert(ne2 == ne02); + assert(ne3 == ne03); + + UNUSED(ne0); + UNUSED(ne1); + UNUSED(ne2); + UNUSED(ne3); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + wsp_ggml_vec_sum_f32(ne00, + (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + + *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; + } + } + } +} + +static void wsp_ggml_compute_forward_mean( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_mean_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_argmax + +static void wsp_ggml_compute_forward_argmax_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + assert(dst->nb[0] == sizeof(float)); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + + const size_t nb01 = src0->nb[1]; + const size_t nb0 = dst->nb[0]; + + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src = (float *) ((char *) src0->data + i1*nb01); + int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0); + int v = 0; + wsp_ggml_vec_argmax_f32(ne00, &v, src); + dst_[0] = v; + } +} + +static void wsp_ggml_compute_forward_argmax( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct 
wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_argmax_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_repeat + +static void wsp_ggml_compute_forward_repeat_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(params->ith == 0); + WSP_GGML_ASSERT(wsp_ggml_can_repeat(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + // guaranteed to be an integer due to the check in wsp_ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); + + // TODO: support for transposed / permuted tensors + WSP_GGML_ASSERT(nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb00 == sizeof(float)); + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + wsp_ggml_vec_cpy_f32(ne00, + (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0), + (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01)); + } + } + } + } + } + } + } +} + +static void wsp_ggml_compute_forward_repeat( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_repeat_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_repeat_back + +static void wsp_ggml_compute_forward_repeat_back_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(params->ith == 0); + WSP_GGML_ASSERT(wsp_ggml_can_repeat(dst, src0)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + // guaranteed to be an integer due to the check in wsp_ggml_can_repeat + const int nr0 = (int)(ne00/ne0); + const int nr1 = (int)(ne01/ne1); + const int nr2 = (int)(ne02/ne2); + const int nr3 = (int)(ne03/ne3); + + // TODO: support for transposed / permuted tensors + WSP_GGML_ASSERT(nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb00 == sizeof(float)); + + if (wsp_ggml_is_contiguous(dst)) { + wsp_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + } else { + for (int k3 = 0; k3 < ne3; k3++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int k1 = 0; k1 < ne1; k1++) { + wsp_ggml_vec_set_f32(ne0, + (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3), + 0); + } + } + } + } + + // TODO: maybe this is not optimal? 
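+ // accumulate the gradient of every repeated copy back into the smaller dst:
+ // the i* loops walk the repeat factors nr0..nr3, the k* loops walk dst's own
+ // extents, and each ne0-wide segment of src0 is added onto the matching dst row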
+ for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne3; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne1; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + wsp_ggml_vec_acc_f32(ne0, + (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1), + (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00)); + } + } + } + } + } + } + } +} + +static void wsp_ggml_compute_forward_repeat_back( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_repeat_back_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_abs + +static void wsp_ggml_compute_forward_abs_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + wsp_ggml_vec_abs_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void wsp_ggml_compute_forward_abs( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_abs_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_sgn + +static void wsp_ggml_compute_forward_sgn_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + wsp_ggml_vec_sgn_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void wsp_ggml_compute_forward_sgn( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_sgn_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_neg + +static void wsp_ggml_compute_forward_neg_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + 
wsp_ggml_vec_neg_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void wsp_ggml_compute_forward_neg( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_neg_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_step + +static void wsp_ggml_compute_forward_step_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + wsp_ggml_vec_step_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void wsp_ggml_compute_forward_step( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_step_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_tanh + +static void wsp_ggml_compute_forward_tanh_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + wsp_ggml_vec_tanh_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void wsp_ggml_compute_forward_tanh( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_tanh_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_elu + +static void wsp_ggml_compute_forward_elu_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + wsp_ggml_vec_elu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void wsp_ggml_compute_forward_elu( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + 
wsp_ggml_compute_forward_elu_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_relu + +static void wsp_ggml_compute_forward_relu_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + wsp_ggml_vec_relu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void wsp_ggml_compute_forward_relu( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_relu_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_gelu + +static void wsp_ggml_compute_forward_gelu_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + wsp_ggml_vec_gelu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void wsp_ggml_compute_forward_gelu( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_gelu_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_gelu_quick + +static void wsp_ggml_compute_forward_gelu_quick_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + wsp_ggml_vec_gelu_quick_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + 
(float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void wsp_ggml_compute_forward_gelu_quick( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_gelu_quick_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_silu + +static void wsp_ggml_compute_forward_silu_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + wsp_ggml_vec_silu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void wsp_ggml_compute_forward_silu( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_silu_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + + +// wsp_ggml_compute_forward_silu_back + +static void wsp_ggml_compute_forward_silu_back_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * grad, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(grad)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, grad)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + wsp_ggml_vec_silu_backward_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1])), + (float *) ((char *) grad->data + i1*(grad->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void wsp_ggml_compute_forward_silu_back( + const struct wsp_ggml_compute_params * params, + const struct 
wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * grad, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_silu_back_f32(params, src0, grad, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_norm + +static void wsp_ggml_compute_forward_norm_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + WSP_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + const float eps = 1e-5f; // TODO: make this a parameter + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + wsp_ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (wsp_ggml_float)x[i00]; + } + + float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + wsp_ggml_float sum2 = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + float v = x[i00] - mean; + y[i00] = v; + sum2 += (wsp_ggml_float)(v*v); + } + + float variance = sum2/ne00; + const float scale = 1.0f/sqrtf(variance + eps); + + wsp_ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void wsp_ggml_compute_forward_norm( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_norm_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +static void wsp_ggml_compute_forward_rms_norm_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + WSP_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + const float eps = 1e-6f; // TODO: make this a parameter + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + wsp_ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (wsp_ggml_float)(x[i00] * x[i00]); + } + + const float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + memcpy(y, x, ne00 * sizeof(float)); + // for (int i00 = 0; i00 < ne00; i00++) { + // y[i00] = x[i00]; + // } + + const float scale = 1.0f/sqrtf(mean + eps); + + wsp_ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void wsp_ggml_compute_forward_rms_norm( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_rms_norm_f32(params, src0, dst); + } break; + default: + { + 
WSP_GGML_ASSERT(false); + } break; + } +} + + +static void wsp_ggml_compute_forward_rms_norm_back_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst) && wsp_ggml_are_same_shape(src0, src1)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + WSP_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + const float eps = 1e-6f; // TODO: make this a parameter + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + // src1 is same shape as src0 => same indices + const int64_t i11 = i01; + const int64_t i12 = i02; + const int64_t i13 = i03; + + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); + + wsp_ggml_float sum_xx = 0.0; + wsp_ggml_float sum_xdz = 0.0; + + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum_xx += (wsp_ggml_float)(x[i00] * x[i00]); + sum_xdz += (wsp_ggml_float)(x[i00] * dz[i00]); + } + + //const float mean = (float)(sum_xx)/ne00; + const float mean_eps = (float)(sum_xx)/ne00 + eps; + const float sum_eps = (float)(sum_xx) + eps*ne00; + //const float mean_xdz = (float)(sum_xdz)/ne00; + // we could cache rms from forward pass to improve performance. + // to do this implement wsp_ggml_rms and compose wsp_ggml_rms_norm using wsp_ggml_rms. + //const float rms = sqrtf(mean_eps); + const float rrms = 1.0f / sqrtf(mean_eps); + //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3) + + { + // z = rms_norm(x) + // + // rms_norm(src0) = + // scale( + // src0, + // div( + // 1, + // sqrt( + // add( + // scale( + // sum( + // sqr( + // src0)), + // (1.0/N)), + // eps)))); + + // postorder: + // ## op args grad + // 00 param src0 grad[#00] + // 01 const 1 + // 02 sqr (#00) grad[#02] + // 03 sum (#02) grad[#03] + // 04 const 1/N + // 05 scale (#03, #04) grad[#05] + // 06 const eps + // 07 add (#05, #06) grad[#07] + // 08 sqrt (#07) grad[#08] + // 09 div (#01,#08) grad[#09] + // 10 scale (#00,#09) grad[#10] + // + // backward pass, given grad[#10] + // #10: scale + // grad[#00] += scale(grad[#10],#09) + // grad[#09] += sum(mul(grad[#10],#00)) + // #09: div + // grad[#08] += neg(mul(grad[#09], div(#09,#08))) + // #08: sqrt + // grad[#07] += mul(grad[#08], div(0.5, #08)) + // #07: add + // grad[#05] += grad[#07] + // #05: scale + // grad[#03] += scale(grad[#05],#04) + // #03: sum + // grad[#02] += repeat(grad[#03], #02) + // #02: + // grad[#00] += scale(mul(#00, grad[#02]), 2.0) + // + // substitute and simplify: + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#02] = repeat(grad[#03], #02) + // grad[#02] = repeat(scale(grad[#05],#04), #02) + // grad[#02] = repeat(scale(grad[#07],#04), #02) + // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * 
div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N))) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps))) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps)) + // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps)) + // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps)) + // a = b*c + d*e + // a = b*c*f/f + d*e*f/f + // a = (b*c*f + d*e*f)*(1/f) + // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c)) + // a = (b + d*e/c)*c + // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps) + // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms + // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms + // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms + // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms + // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms + // a = (dz + x*div(-mean_xdz,mean_eps))*rrms + // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms) + // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + } + // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + // post-order: + // dx := x + // dx := scale(dx,-mean_xdz/mean_eps) + // dx := add(dx, dz) + // dx := scale(dx, rrms) + float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + wsp_ggml_vec_cpy_f32 (ne00, dx, x); + // wsp_ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); + wsp_ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); + wsp_ggml_vec_acc_f32 (ne00, dx, dz); + wsp_ggml_vec_scale_f32(ne00, dx, rrms); + } + } + } +} + +static void wsp_ggml_compute_forward_rms_norm_back( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_rms_norm_back_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + + +// wsp_ggml_compute_forward_mul_mat + +#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) +// helper function to determine if it is better to use BLAS or not +// for large matrices, BLAS is faster +static bool wsp_ggml_compute_forward_mul_mat_use_blas( + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + //const int64_t 
ne00 = src0->ne[0]; + //const int64_t ne01 = src0->ne[1]; + + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + if (wsp_ggml_is_contiguous(src0) && + wsp_ggml_is_contiguous(src1) && + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { + + /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ + return true; + } + + return false; +} +#endif + +static void wsp_ggml_compute_forward_mul_mat_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + assert(ne02 == ne12); + assert(ne03 == ne13); + assert(ne2 == ne12); + assert(ne3 == ne13); + + // we don't support permuted src0 or src1 + assert(nb00 == sizeof(float)); + assert(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + assert(nb0 == sizeof(float)); + assert(nb0 <= nb1); + assert(nb1 <= nb2); + assert(nb2 <= nb3); + + assert(ne0 == ne01); + assert(ne1 == ne11); + assert(ne2 == ne02); + assert(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + +#if defined(WSP_GGML_USE_CLBLAST) + if (wsp_ggml_cl_can_mul_mat(src0, src1, dst)) { + if (params->ith == 0 && params->type == WSP_GGML_TASK_COMPUTE) { + wsp_ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); + } + return; + } +#endif + +#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) + if (wsp_ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + if (params->ith != 0) { + return; + } + + if (params->type == WSP_GGML_TASK_INIT) { + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne00, + 0.0f, d, ne01); + } + } + //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (wsp_ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type == WSP_GGML_TASK_INIT) { + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // parallelize by src0 rows using wsp_ggml_vec_dot_f32 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + for (int64_t ic = 0; ic < ne11; ++ic) { + // src1 indices + const int i13 = i03; + const int i12 = i02; + const int i11 = ic; + + // dst indices + const int i0 = i01; + const int i1 = i11; + const int i2 = i02; + const int i3 = i03; + + wsp_ggml_vec_dot_f32(ne00, + (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)), + (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13))); + } + } + + //int64_t 
t1 = wsp_ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void wsp_ggml_compute_forward_mul_mat_f16_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + //const int64_t ne = ne0*ne1*ne2*ne3; + + const int ith = params->ith; + const int nth = params->nth; + + WSP_GGML_ASSERT(ne02 == ne12); + WSP_GGML_ASSERT(ne03 == ne13); + WSP_GGML_ASSERT(ne2 == ne12); + WSP_GGML_ASSERT(ne3 == ne13); + + // TODO: we don't support permuted src0 + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + + // dst cannot be transposed or permuted + WSP_GGML_ASSERT(nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb0 <= nb1); + WSP_GGML_ASSERT(nb1 <= nb2); + WSP_GGML_ASSERT(nb2 <= nb3); + + WSP_GGML_ASSERT(ne0 == ne01); + WSP_GGML_ASSERT(ne1 == ne11); + WSP_GGML_ASSERT(ne2 == ne02); + WSP_GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + +#if defined(WSP_GGML_USE_CLBLAST) + if (wsp_ggml_cl_can_mul_mat(src0, src1, dst)) { + if (params->ith == 0 && params->type == WSP_GGML_TASK_COMPUTE) { + wsp_ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); + } + return; + } +#endif + +#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) + if (wsp_ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + WSP_GGML_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) { + return; + } + + if (params->type == WSP_GGML_TASK_INIT) { + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + float * const wdata = params->wdata; + { + size_t id = 0; + for (int64_t i01 = 0; i01 < ne01; ++i01) { + for (int64_t i00 = 0; i00 < ne00; ++i00) { + wdata[id++] = WSP_GGML_FP16_TO_FP32(*(wsp_ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); + } + } + + assert(id*sizeof(float) <= params->wsize); + } + + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // zT = y * xT + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne00, + 0.0f, d, ne01); + } + } + + /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (wsp_ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ + + return; + } +#endif + + if (params->type == WSP_GGML_TASK_INIT) { + wsp_ggml_fp16_t * const wdata = params->wdata; + + size_t id = 0; + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + wdata[id++] = WSP_GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); + } + } 
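+ // every row of src1 in the current (i12, i13) slice has now been converted
+ // to fp16 and appended to wdata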
+ } + } + + WSP_GGML_ASSERT(id*sizeof(wsp_ggml_fp16_t) <= params->wsize); + + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // fp16 -> half the size, so divide by 2 + // TODO: do not support transposed src1 + assert(nb10/2 == sizeof(wsp_ggml_fp16_t)); + + // parallelize by src0 rows using wsp_ggml_vec_dot_f16 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + wsp_ggml_fp16_t * wdata = params->wdata; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int i13 = i03; + const int i12 = i02; + + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; + + wsp_ggml_fp16_t * src0_row = (wsp_ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + wsp_ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00; + + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + + for (int64_t ic = 0; ic < ne11; ++ic) { + wsp_ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); + } + } + + //int64_t t1 = wsp_ggml_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void wsp_ggml_compute_forward_mul_mat_q_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + WSP_GGML_ASSERT(ne02 == ne12); + WSP_GGML_ASSERT(ne03 == ne13); + WSP_GGML_ASSERT(ne2 == ne12); + WSP_GGML_ASSERT(ne3 == ne13); + + const enum wsp_ggml_type type = src0->type; + quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot; + vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q; + enum wsp_ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type; + + // we don't support permuted src0 or src1 + WSP_GGML_ASSERT(nb00 == WSP_GGML_TYPE_SIZE[type]); + WSP_GGML_ASSERT(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + WSP_GGML_ASSERT(nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb0 <= nb1); + WSP_GGML_ASSERT(nb1 <= nb2); + WSP_GGML_ASSERT(nb2 <= nb3); + + WSP_GGML_ASSERT(ne0 == ne01); + WSP_GGML_ASSERT(ne1 == ne11); + WSP_GGML_ASSERT(ne2 == ne02); + WSP_GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + +#if defined(WSP_GGML_USE_CLBLAST) + if (wsp_ggml_cl_can_mul_mat(src0, src1, dst)) { + if (params->ith == 0 && params->type == WSP_GGML_TASK_COMPUTE) { + wsp_ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); + } + return; + } +#endif + +#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) + if (wsp_ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + if 
(params->ith != 0) { + return; + } + + if (params->type == WSP_GGML_TASK_INIT) { + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + float * const wdata = params->wdata; + dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + { + size_t id = 0; + for (int64_t i01 = 0; i01 < ne01; ++i01) { + dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); + id += ne00; + } + + assert(id*sizeof(float) <= params->wsize); + } + + const float * x = wdata; + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne00, + 0.0f, d, ne01); + } + } + + //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (wsp_ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type == WSP_GGML_TASK_INIT) { + char * wdata = params->wdata; + const size_t row_size = ne10*WSP_GGML_TYPE_SIZE[vec_dot_type]/WSP_GGML_BLCK_SIZE[vec_dot_type]; + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += row_size; + } + } + } + + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // parallelize by src0 rows using wsp_ggml_vec_dot_q + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + void * wdata = params->wdata; + const size_t row_size = ne00*WSP_GGML_TYPE_SIZE[vec_dot_type]/WSP_GGML_BLCK_SIZE[vec_dot_type]; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int i13 = i03; + const int i12 = i02; + + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; + + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size)); + + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + for (int64_t ic = 0; ic < ne11; ++ic) { + vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size)); + } + } + + //int64_t t1 = wsp_ggml_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void wsp_ggml_compute_forward_mul_mat( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_Q4_0: + case WSP_GGML_TYPE_Q4_1: + case 
WSP_GGML_TYPE_Q5_0: + case WSP_GGML_TYPE_Q5_1: + case WSP_GGML_TYPE_Q8_0: + case WSP_GGML_TYPE_Q8_1: + case WSP_GGML_TYPE_Q2_K: + case WSP_GGML_TYPE_Q3_K: + case WSP_GGML_TYPE_Q4_K: + case WSP_GGML_TYPE_Q5_K: + case WSP_GGML_TYPE_Q6_K: + { + wsp_ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_mul_mat_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_out_prod + + +static void wsp_ggml_compute_forward_out_prod_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + WSP_GGML_ASSERT(ne02 == ne12); + WSP_GGML_ASSERT(ne03 == ne13); + WSP_GGML_ASSERT(ne2 == ne12); + WSP_GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + WSP_GGML_ASSERT(nb00 == sizeof(float)); + + // dst cannot be transposed or permuted + WSP_GGML_ASSERT(nb0 == sizeof(float)); + // WSP_GGML_ASSERT(nb0 <= nb1); + // WSP_GGML_ASSERT(nb1 <= nb2); + // WSP_GGML_ASSERT(nb2 <= nb3); + + WSP_GGML_ASSERT(ne0 == ne00); + WSP_GGML_ASSERT(ne1 == ne10); + WSP_GGML_ASSERT(ne2 == ne02); + WSP_GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + // TODO: #if defined(WSP_GGML_USE_CUBLAS) wsp_ggml_cuda_out_prod + // TODO: #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) || defined(WSP_GGML_USE_CLBLAST) + + if (params->type == WSP_GGML_TASK_INIT) { + wsp_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + for (int64_t ir = ir0; ir < ir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2; + const int64_t i03 = i3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + + for (int64_t i01 = 0; i01 < ne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + wsp_ggml_vec_mad_f32(ne0, d, s0, *s1); + // for (int64_t i0 = 0; i0 < ne0; ++i0) { + // d[i0] += s0[i0] * s1[i1]; + // } + } + } + + //int64_t t1 = wsp_ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = 
%5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void wsp_ggml_compute_forward_out_prod( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_Q4_0: + case WSP_GGML_TYPE_Q4_1: + case WSP_GGML_TYPE_Q5_0: + case WSP_GGML_TYPE_Q5_1: + case WSP_GGML_TYPE_Q8_0: + case WSP_GGML_TYPE_Q8_1: + { + WSP_GGML_ASSERT(false); // todo + // wsp_ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F16: + { + WSP_GGML_ASSERT(false); // todo + // wsp_ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_out_prod_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_scale + +static void wsp_ggml_compute_forward_scale_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + WSP_GGML_ASSERT(wsp_ggml_is_scalar(src1)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // scale factor + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const size_t nb01 = src0->nb[1]; + + const size_t nb1 = dst->nb[1]; + + + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + } + wsp_ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); + } +} + +static void wsp_ggml_compute_forward_scale( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_scale_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_set + +static void wsp_ggml_compute_forward_set_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0)); + + WSP_GGML_ASSERT(opt0->type == WSP_GGML_TYPE_I32); + WSP_GGML_ASSERT(wsp_ggml_nelements(opt0) == 5); + + // view src0 and dst with these strides and data offset inbytes during set + // nb0 is implicitely element_size because src0 and dst are contiguous + size_t nb1 = ((int32_t *) opt0->data)[0]; + size_t nb2 = ((int32_t *) opt0->data)[1]; + size_t nb3 = ((int32_t *) 
opt0->data)[2]; + size_t offset = ((int32_t *) opt0->data)[3]; + bool inplace = (bool) ((int32_t *) opt0->data)[4]; + + if (!inplace && (params->type == WSP_GGML_TASK_INIT)) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + wsp_ggml_nbytes(dst)); + } + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(src1); + const int nc = src1->ne[0]; + + WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); + + // src0 and dst as viewed during set + const size_t nb0 = wsp_ggml_element_size(src0); + + const int im0 = (ne10 == 0 ? 0 : ne10-1); + const int im1 = (ne11 == 0 ? 0 : ne11-1); + const int im2 = (ne12 == 0 ? 0 : ne12-1); + const int im3 = (ne13 == 0 ? 0 : ne13-1); + + WSP_GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= wsp_ggml_nbytes(dst)); + + WSP_GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are viewed with shape of src1 and offset + // => same indices + const int i3 = ir/(ne12*ne11); + const int i2 = (ir - i3*ne12*ne11)/ne11; + const int i1 = (ir - i3*ne12*ne11 - i2*ne11); + + wsp_ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + } +} + +static void wsp_ggml_compute_forward_set( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_set_f32(params, src0, src1, opt0, dst); + } break; + case WSP_GGML_TYPE_F16: + case WSP_GGML_TYPE_Q4_0: + case WSP_GGML_TYPE_Q4_1: + case WSP_GGML_TYPE_Q5_0: + case WSP_GGML_TYPE_Q5_1: + case WSP_GGML_TYPE_Q8_0: + case WSP_GGML_TYPE_Q8_1: + case WSP_GGML_TYPE_Q2_K: + case WSP_GGML_TYPE_Q3_K: + case WSP_GGML_TYPE_Q4_K: + case WSP_GGML_TYPE_Q5_K: + case WSP_GGML_TYPE_Q6_K: + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_cpy + +static void wsp_ggml_compute_forward_cpy( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + wsp_ggml_compute_forward_dup(params, src0, dst); +} + +// wsp_ggml_compute_forward_cont + +static void wsp_ggml_compute_forward_cont( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + wsp_ggml_compute_forward_dup(params, src0, dst); +} + +// wsp_ggml_compute_forward_reshape + +static void wsp_ggml_compute_forward_reshape( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(src0); + UNUSED(dst); +} + +// wsp_ggml_compute_forward_view + +static void wsp_ggml_compute_forward_view( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// wsp_ggml_compute_forward_permute + +static void wsp_ggml_compute_forward_permute( + const struct 
wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// wsp_ggml_compute_forward_transpose + +static void wsp_ggml_compute_forward_transpose( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// wsp_ggml_compute_forward_get_rows + +static void wsp_ggml_compute_forward_get_rows_q( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nelements(src1); + const enum wsp_ggml_type type = src0->type; + dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == WSP_GGML_TYPE_SIZE[type]); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + dequantize_row_q( + (const void *) ((char *) src0->data + r*src0->nb[1]), + (float *) ((char *) dst->data + i*dst->nb[1]), nc); + } +} + +static void wsp_ggml_compute_forward_get_rows_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == sizeof(wsp_ggml_fp16_t)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + for (int j = 0; j < nc; ++j) { + wsp_ggml_fp16_t v = ((wsp_ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; + ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = WSP_GGML_FP16_TO_FP32(v); + } + } +} + +static void wsp_ggml_compute_forward_get_rows_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + wsp_ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i*dst->nb[1]), + (float *) ((char *) src0->data + r*src0->nb[1])); + } +} + +static void wsp_ggml_compute_forward_get_rows( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_Q4_0: + case WSP_GGML_TYPE_Q4_1: + case WSP_GGML_TYPE_Q5_0: + case WSP_GGML_TYPE_Q5_1: + case WSP_GGML_TYPE_Q8_0: + case WSP_GGML_TYPE_Q8_1: + case WSP_GGML_TYPE_Q2_K: + case WSP_GGML_TYPE_Q3_K: + case WSP_GGML_TYPE_Q4_K: + case WSP_GGML_TYPE_Q5_K: + case WSP_GGML_TYPE_Q6_K: + { + wsp_ggml_compute_forward_get_rows_q(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_get_rows_f16(params, src0, src1, dst); + } break; + case 
WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_get_rows_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} +} + +// wsp_ggml_compute_forward_get_rows_back + +static void wsp_ggml_compute_forward_get_rows_back_f32_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(params->ith == 0); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(opt0, dst)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(opt0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); + + wsp_ggml_compute_forward_dup_same_cont(params, opt0, dst); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nelements(src1); + + WSP_GGML_ASSERT( dst->ne[0] == nc); + WSP_GGML_ASSERT(src0->nb[0] == sizeof(wsp_ggml_fp16_t)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + for (int j = 0; j < nc; ++j) { + wsp_ggml_fp16_t v = ((wsp_ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j]; + ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += WSP_GGML_FP16_TO_FP32(v); + } + } +} + +static void wsp_ggml_compute_forward_get_rows_back_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(params->ith == 0); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(opt0, dst)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(opt0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); + + // wsp_ggml_compute_forward_dup_same_cont(params, opt0, dst); + + if (params->type == WSP_GGML_TASK_INIT) { + memset(dst->data, 0, wsp_ggml_nbytes(dst)); + } + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nelements(src1); + + WSP_GGML_ASSERT( dst->ne[0] == nc); + WSP_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + wsp_ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + r*dst->nb[1]), + (float *) ((char *) dst->data + r*dst->nb[1]), + (float *) ((char *) src0->data + i*src0->nb[1])); + } +} + + +static void wsp_ggml_compute_forward_get_rows_back( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst); + } break; + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", 
dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} +} + +// wsp_ggml_compute_forward_diag + +static void wsp_ggml_compute_forward_diag_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(params->ith == 0); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + WSP_GGML_ASSERT(ne00 == ne0); + WSP_GGML_ASSERT(ne00 == ne1); + WSP_GGML_ASSERT(ne01 == 1); + WSP_GGML_ASSERT(ne02 == ne2); + WSP_GGML_ASSERT(ne03 == ne3); + + WSP_GGML_ASSERT(nb00 == sizeof(float)); + WSP_GGML_ASSERT(nb0 == sizeof(float)); + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + for (int i1 = 0; i1 < ne1; i1++) { + float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02); + for (int i0 = 0; i0 < i1; i0++) { + d[i0] = 0; + } + d[i1] = s[i1]; + for (int i0 = i1+1; i0 < ne0; i0++) { + d[i0] = 0; + } + } + } + } +} + +static void wsp_ggml_compute_forward_diag( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_diag_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_diag_mask_inf + +static void wsp_ggml_compute_forward_diag_mask_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst, + const float value) { + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_I32); + WSP_GGML_ASSERT(wsp_ggml_nelements(src1) == 2); + + const int ith = params->ith; + const int nth = params->nth; + + const int n_past = ((int32_t *) src1->data)[0]; + const bool inplace = (bool)((int32_t *) src1->data)[1]; + + WSP_GGML_ASSERT(n_past >= 0); + + if (!inplace && (params->type == WSP_GGML_TASK_INIT)) { + // memcpy needs to be synchronized across threads to avoid race conditions. 
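+ // (in COMPUTE each thread masks only its own subset of rows, so a copy done there could overwrite rows already written by other threads)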
+ // => do it in INIT phase + WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0)); + memcpy( + ((char *) dst->data), + ((char *) src0->data), + wsp_ggml_nbytes(dst)); + } + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + const int nr = src0->ne[1]; + const int nz = n/nr; + + WSP_GGML_ASSERT( dst->nb[0] == sizeof(float)); + WSP_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int k = 0; k < nz; k++) { + for (int j = ith; j < nr; j += nth) { + for (int i = n_past; i < nc; i++) { + if (i > n_past + j) { + *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value; + } + } + } + } +} + +static void wsp_ggml_compute_forward_diag_mask_inf( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_diag_mask_f32(params, src0, src1, dst, -INFINITY); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +static void wsp_ggml_compute_forward_diag_mask_zero( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_diag_mask_f32(params, src0, src1, dst, 0); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_soft_max + +static void wsp_ggml_compute_forward_soft_max_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float *sp = (float *)((char *) src0->data + i1*src0->nb[1]); + float *dp = (float *)((char *) dst->data + i1*dst->nb[1]); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(sp[i])); + } +#endif + + float max = -INFINITY; + wsp_ggml_vec_max_f32(nc, &max, sp); + + wsp_ggml_float sum = 0.0; + + uint16_t scvt; + for (int i = 0; i < nc; i++) { + if (sp[i] == -INFINITY) { + dp[i] = 0.0f; + } else { + // const float val = (sp[i] == -INFINITY) ? 
0.0 : exp(sp[i] - max); + wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(sp[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = WSP_GGML_FP16_TO_FP32(table_exp_f16[scvt]); + sum += (wsp_ggml_float)val; + dp[i] = val; + } + } + + assert(sum > 0.0); + + sum = 1.0/sum; + wsp_ggml_vec_scale_f32(nc, dp, sum); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(dp[i])); + assert(!isinf(dp[i])); + } +#endif + } +} + +static void wsp_ggml_compute_forward_soft_max( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_soft_max_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_soft_max_back + +static void wsp_ggml_compute_forward_soft_max_back_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src1, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float *dy = (float *)((char *) src0->data + i1*src0->nb[1]); + float *y = (float *)((char *) src1->data + i1*src1->nb[1]); + float *dx = (float *)((char *) dst->data + i1*dst->nb[1]); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(dy[i])); + assert(!isnan(y[i])); + } +#endif + // Jii = yi - yi*yi + // Jij = -yi*yj + // J = diag(y)-y.T*y + // dx = J * dy + // dxk = sum_i(Jki * dyi) + // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk + // dxk = sum_i(-yk*yi * dyi) + yk*dyk + // dxk = -yk * sum_i(yi * dyi) + yk*dyk + // dxk = -yk * dot(y, dy) + yk*dyk + // dxk = yk * (- dot(y, dy) + dyk) + // dxk = yk * (dyk - dot(y, dy)) + // + // post-order: + // dot_y_dy := dot(y, dy) + // dx := dy + // dx := dx - dot_y_dy + // dx := dx * y + + // linear runtime, no additional memory + float dot_y_dy = 0; + wsp_ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy); + wsp_ggml_vec_cpy_f32 (nc, dx, dy); + wsp_ggml_vec_acc1_f32(nc, dx, -dot_y_dy); + wsp_ggml_vec_mul_f32 (nc, dx, dx, y); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(dx[i])); + assert(!isinf(dx[i])); + } +#endif + } +} + +static void wsp_ggml_compute_forward_soft_max_back( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_alibi + +static void wsp_ggml_compute_forward_alibi_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const 
struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_I32); + WSP_GGML_ASSERT(wsp_ggml_nelements(src1) == 3); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_head = ((int32_t *) src1->data)[1]; + const float max_bias = ((float *) src1->data)[2]; + + assert(n_past >= 0); + + const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 + const int ne1 = src0->ne[1]; // seq_len_without_past + //const int ne2 = src0->ne[2]; // n_head -> this is k + //const int ne3 = src0->ne[3]; // 1 -> bsz + + const int n = wsp_ggml_nrows(src0); + const int ne2_ne3 = n/ne1; // ne2*ne3 + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + //const int nb3 = src0->nb[3]; + + assert(nb0 == sizeof(float)); + assert(ne1 + n_past == ne0); (void) n_past; + + // add alibi to src0 (KQ_scaled) + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + + for (int i = 0; i < ne0; i++) { + for (int j = 0; j < ne1; j++) { + for (int k = 0; k < ne2_ne3; k++) { + float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); + float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); + + // TODO: k*nb2 or k*nb3 + + float m_k; + + if (k < n_heads_log2_floor) { + m_k = powf(m0, k + 1); + } else { + m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); + } + + pdst[0] = (i-ne0+1) * m_k + src[0]; + + } + } + } +} + +static void wsp_ggml_compute_forward_alibi_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_I32); + WSP_GGML_ASSERT(wsp_ggml_nelements(src1) == 3); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_head = ((int32_t *) src1->data)[1]; + const float max_bias = ((float *) src1->data)[2]; + + assert(n_past >= 0); + + const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 + const int ne1 = src0->ne[1]; // seq_len_without_past + //const int ne2 = src0->ne[2]; // n_head -> this is k + //const int ne3 = src0->ne[3]; // 1 -> bsz + + const int n = wsp_ggml_nrows(src0); + const int ne2_ne3 = n/ne1; // ne2*ne3 + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + //const int nb3 = src0->nb[3]; + + assert(nb0 == sizeof(wsp_ggml_fp16_t)); + assert(ne1 + n_past == ne0); (void) n_past; + + // add alibi to src0 (KQ_scaled) + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + + for (int i = 0; i < ne0; i++) { + for (int j = 0; j < ne1; j++) { + for (int k = 0; k < ne2_ne3; k++) { + wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); + float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); + + // TODO: k*nb2 or k*nb3 + + float m_k; + + if (k < n_heads_log2_floor) { + m_k = powf(m0, k + 1); + } else { + m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); + } + + // we 
return F32 + pdst[0] = (i-ne0+1) * m_k + WSP_GGML_FP16_TO_FP32(src[0]); + } + } + } +} + +static void wsp_ggml_compute_forward_alibi( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_alibi_f16(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_alibi_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_Q4_0: + case WSP_GGML_TYPE_Q4_1: + case WSP_GGML_TYPE_Q5_0: + case WSP_GGML_TYPE_Q5_1: + case WSP_GGML_TYPE_Q8_0: + case WSP_GGML_TYPE_Q8_1: + case WSP_GGML_TYPE_Q2_K: + case WSP_GGML_TYPE_Q3_K: + case WSP_GGML_TYPE_Q4_K: + case WSP_GGML_TYPE_Q5_K: + case WSP_GGML_TYPE_Q6_K: + case WSP_GGML_TYPE_Q8_K: + case WSP_GGML_TYPE_I8: + case WSP_GGML_TYPE_I16: + case WSP_GGML_TYPE_I32: + case WSP_GGML_TYPE_COUNT: + { + WSP_GGML_ASSERT(false); + } break; + } +} + + +// wsp_ggml_compute_forward_clamp + +static void wsp_ggml_compute_forward_clamp_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT(wsp_ggml_nelements(src1) == 2); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const float min = ((float *) src1->data)[0]; + const float max = ((float *) src1->data)[1]; + + const int ith = params->ith; + const int nth = params->nth; + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + WSP_GGML_ASSERT( nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb00 == sizeof(float)); + + for (int j = ith; j < n; j += nth) { + float * dst_ptr = (float *) ((char *) dst->data + j*nb1); + float * src0_ptr = (float *) ((char *) src0->data + j*nb01); + + for (int i = 0; i < nc; i++) { + dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min); + } + } +} + +static void wsp_ggml_compute_forward_clamp( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_clamp_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F16: + case WSP_GGML_TYPE_Q4_0: + case WSP_GGML_TYPE_Q4_1: + case WSP_GGML_TYPE_Q5_0: + case WSP_GGML_TYPE_Q5_1: + case WSP_GGML_TYPE_Q8_0: + case WSP_GGML_TYPE_Q8_1: + case WSP_GGML_TYPE_Q2_K: + case WSP_GGML_TYPE_Q3_K: + case WSP_GGML_TYPE_Q4_K: + case WSP_GGML_TYPE_Q5_K: + case WSP_GGML_TYPE_Q6_K: + case WSP_GGML_TYPE_Q8_K: + case WSP_GGML_TYPE_I8: + case WSP_GGML_TYPE_I16: + case WSP_GGML_TYPE_I32: + case WSP_GGML_TYPE_COUNT: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_rope + +static void wsp_ggml_compute_forward_rope_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_I32); + WSP_GGML_ASSERT(wsp_ggml_nelements(src1) == 4); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + const 
int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + const int n_ctx = ((int32_t *) src1->data)[3]; + + assert(n_past >= 0); + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + WSP_GGML_ASSERT(nb00 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(dst); + + WSP_GGML_ASSERT(n_dims <= ne0); + WSP_GGML_ASSERT(n_dims % 2 == 0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(10000.0, -2.0f/n_dims); + + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + float theta = (float)p; + + if (is_glm) { + theta = MIN(p, n_ctx - 2); + float block_theta = MAX(p - (n_ctx - 2), 0); + for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + const float cos_block_theta = cosf(block_theta); + const float sin_block_theta = sinf(block_theta); + + theta *= theta_scale; + block_theta *= theta_scale; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + const float x2 = src[n_dims]; + const float x3 = src[n_dims/2*3]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta; + dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta; + } + } else if (!is_neox) { + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } + } else { + // TODO: this is probably wrong, but I can't figure it out .. 
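+ // (note: this branch rotates the pair (i0, i0 + n_dims/2) rather than the adjacent pair used in the branch above)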
+ // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } + } + } + } + } +} + +static void wsp_ggml_compute_forward_rope_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_I32); + WSP_GGML_ASSERT(wsp_ggml_nelements(src1) == 4); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + const int n_ctx = ((int32_t *) src1->data)[3]; + + assert(n_past >= 0); + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + WSP_GGML_ASSERT(nb0 == sizeof(wsp_ggml_fp16_t)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(dst); + + WSP_GGML_ASSERT(n_dims <= ne0); + WSP_GGML_ASSERT(n_dims % 2 == 0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(10000.0, -2.0f/n_dims); + + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int64_t p = ((mode & 1) == 0 ? 
n_past + i2 : i2); + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + float theta = (float)p; + + if (is_glm) { + theta = MIN(p, n_ctx - 2); + float block_theta = MAX(p - (n_ctx - 2), 0); + for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + const float cos_block_theta = cosf(block_theta); + const float sin_block_theta = sinf(block_theta); + + theta *= theta_scale; + block_theta *= theta_scale; + + const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = WSP_GGML_FP16_TO_FP32(src[0]); + const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims/2]); + const float x2 = WSP_GGML_FP16_TO_FP32(src[n_dims]); + const float x3 = WSP_GGML_FP16_TO_FP32(src[n_dims/2*3]); + + dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + dst_data[n_dims] = WSP_GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); + dst_data[n_dims/2*3] = WSP_GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); + } + } else if (!is_neox) { + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = WSP_GGML_FP16_TO_FP32(src[0]); + const float x1 = WSP_GGML_FP16_TO_FP32(src[1]); + + dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[1] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } else { + // TODO: this is probably wrong, but I can't figure it out ..
+ // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = WSP_GGML_FP16_TO_FP32(src[0]); + const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims/2]); + + dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } + } + } + } + } +} + +static void wsp_ggml_compute_forward_rope( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_rope_f16(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_rope_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_rope_back + +static void wsp_ggml_compute_forward_rope_back_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + assert(src1->type == WSP_GGML_TYPE_I32); + assert(wsp_ggml_nelements(src1) == 3); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // y = rope(x, src1) + // dx = rope_back(dy, src1) + // src0 is dy, src1 contains options + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + + assert(n_past >= 0); + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + assert(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(dst); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(10000.0, -2.0f/n_dims); + + const bool is_neox = mode & 2; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int64_t p = ((mode & 1) == 0 ? 
n_past + i2 : i2); + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + float theta = (float)p; + + if (!is_neox) { + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float dy0 = dy[0]; + const float dy1 = dy[1]; + + dx[0] = dy0*cos_theta + dy1*sin_theta; + dx[1] = - dy0*sin_theta + dy1*cos_theta; + } + } else { + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float dy0 = dy[0]; + const float dy1 = dy[n_dims/2]; + + dx[0] = dy0*cos_theta + dy1*sin_theta; + dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta; + } + } + } + } + } + } +} + +static void wsp_ggml_compute_forward_rope_back_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + assert(src1->type == WSP_GGML_TYPE_I32); + assert(wsp_ggml_nelements(src1) == 3); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // y = rope(x, src1) + // dx = rope_back(dy, src1) + // src0 is dy, src1 contains options + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + + assert(n_past >= 0); + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + assert(nb0 == sizeof(wsp_ggml_fp16_t)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = wsp_ggml_nrows(dst); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(10000.0, -2.0f/n_dims); + + const bool is_neox = mode & 2; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int64_t p = ((mode & 1) == 0 ? 
n_past + i2 : i2); + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + float theta = (float)p; + + if (!is_neox) { + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const wsp_ggml_fp16_t * const dy = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + wsp_ggml_fp16_t * dx = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float dy0 = WSP_GGML_FP16_TO_FP32(dy[0]); + const float dy1 = WSP_GGML_FP16_TO_FP32(dy[1]); + + dx[0] = WSP_GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta); + dx[1] = WSP_GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta); + } + } else { + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + const wsp_ggml_fp16_t * const dy = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + wsp_ggml_fp16_t * dx = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float dy0 = WSP_GGML_FP16_TO_FP32(dy[0]); + const float dy1 = WSP_GGML_FP16_TO_FP32(dy[n_dims/2]); + + dx[0] = WSP_GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta); + dx[n_dims/2] = WSP_GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta); + } + } + } + } + } + } +} + +static void wsp_ggml_compute_forward_rope_back( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_rope_back_f16(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_rope_back_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_conv_1d + +static void wsp_ggml_compute_forward_conv_1d_s1_ph_f16_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32); + + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = wsp_ggml_up32(ne01); + + WSP_GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == WSP_GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + wsp_ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) 
params->wdata + ne02*ew0*ne00; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + wsp_ggml_fp16_t * dst_data = wdata; + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]); + } + } + } + + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int64_t i0 = 0; i0 < ne10; ++i0) { + dst_data[i0] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + wsp_ggml_vec_dot_f16(ew0, &v, + (wsp_ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (wsp_ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0] += v; + } + } + } +} + +static void wsp_ggml_compute_forward_conv_1d_s1_ph_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32); + + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = wsp_ggml_up32(ne01); + + WSP_GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + WSP_GGML_ASSERT(nb00 == sizeof(float)); + WSP_GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == WSP_GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + float * const wdata = (float *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i02*ew0*ne00; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + } + } + } + + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int64_t i0 = 0; i0 < ne10; ++i0) { + dst_data[i0] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + wsp_ggml_vec_dot_f32(ew0, &v, + (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0] += v; + } + } + } +} + +static void wsp_ggml_compute_forward_conv_1d_s1_ph( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct 
wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +static void wsp_ggml_compute_forward_conv_1d_s2_ph_f16_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32); + + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = wsp_ggml_up32(ne01); + + WSP_GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == WSP_GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + wsp_ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + wsp_ggml_fp16_t * dst_data = wdata; + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]); + } + } + } + + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int64_t i0 = 0; i0 < ne10; i0 += 2) { + dst_data[i0/2] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + wsp_ggml_vec_dot_f16(ew0, &v, + (wsp_ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (wsp_ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0/2] += v; + } + } + } +} + +static void wsp_ggml_compute_forward_conv_1d_s2_ph_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32); + + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = wsp_ggml_up32(ne01); + + WSP_GGML_ASSERT(ne00 % 2 == 1); // TODO: 
support even kernel sizes + WSP_GGML_ASSERT(nb00 == sizeof(float)); + WSP_GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == WSP_GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + float * const wdata = (float *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i02*ew0*ne00; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + } + } + } + + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int64_t i0 = 0; i0 < ne10; i0 += 2) { + dst_data[i0/2] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + wsp_ggml_vec_dot_f32(ew0, &v, + (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0/2] += v; + } + } + } +} + +static void wsp_ggml_compute_forward_conv_1d_s2_ph( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_conv_1d + +static void wsp_ggml_compute_forward_conv_1d( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + const int32_t s0 = ((const int32_t*)(opt0->data))[0]; + const int32_t p0 = ((const int32_t*)(opt0->data))[1]; + const int32_t d0 = ((const int32_t*)(opt0->data))[2]; + WSP_GGML_ASSERT(d0 == 1); // dilation not supported + WSP_GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported + if (s0 == 1) { + wsp_ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst); + } else if (s0 == 2) { + wsp_ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst); + } else { + WSP_GGML_ASSERT(false); // only stride 1 and 2 supported + }; +} + +// wsp_ggml_compute_forward_conv_2d_sk_p0 + +static void wsp_ggml_compute_forward_conv_2d_sk_p0_f16_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32); + + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + 
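+    // What this kernel computes, as a hedged reference sketch derived from the loops
+    // below (not executed here): with stride equal to the kernel size and no padding
+    // ("sk_p0"), every output element is the dot product of one flattened filter with
+    // one non-overlapping input patch:
+    //
+    //   for (i2 = 0 .. ne03)                         // filters / output channels
+    //     for (i1 = 0 .. ne1), (i0 = 0 .. ne0)       // output rows / cols
+    //       dst[i2][i1][i0] = sum over c  in [0, ne02)   // input channels (== ne12)
+    //                                  ky in [0, nk1)
+    //                                  kx in [0, nk0)
+    //         of src0[i2][c][ky][kx] * src1[c][i1*nk1 + ky][i0*nk0 + kx]
+    //
+    // The INIT phase unrolls the src1 patches into wdata (im2col style) so the triple
+    // sum becomes a single wsp_ggml_vec_dot_f16 of length ew0 = nk0*nk1*ne02.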
WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk0 = ne00; + const int nk1 = ne01; + + // size of the convolution row - the kernel size unrolled across all channels + const int ew0 = nk0*nk1*ne02; + + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == WSP_GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare source data (src1) + { + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; + + for (int i12 = 0; i12 < ne12; i12++) { + const float * const src = (float *)((char *) src1->data + i12*nb12); + wsp_ggml_fp16_t * dst_data = wdata; + + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + for (int ik1 = 0; ik1 < nk1; ik1++) { + for (int ik0 = 0; ik0 < nk0; ik0++) { + dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] = + WSP_GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]); + } + } + } + } + } + } + + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // total patches in dst + const int np = ne2; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; + + for (int i2 = ip0; i2 < ip1; i2++) { + float * dst_data = (float *)((char *) dst->data + i2*nb2); + + for (int i1 = 0; i1 < ne1; ++i1) { + for (int i0 = 0; i0 < ne0; ++i0) { + wsp_ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0, + (wsp_ggml_fp16_t *) ((char *) src0->data + i2*nb03), + (wsp_ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0); + } + } + } +} + +static void wsp_ggml_compute_forward_conv_2d_sk_p0( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F32: + { + //wsp_ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst); + WSP_GGML_ASSERT(false); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_conv_2d + +static void wsp_ggml_compute_forward_conv_2d( + const struct wsp_ggml_compute_params* params, + const struct wsp_ggml_tensor* src0, + const struct wsp_ggml_tensor* src1, + const struct wsp_ggml_tensor* opt0, + struct wsp_ggml_tensor* dst) { + const int32_t s0 = ((const int32_t*)(opt0->data))[0]; + const int32_t s1 = ((const int32_t*)(opt0->data))[1]; + const int32_t p0 = ((const int32_t*)(opt0->data))[2]; + const int32_t p1 = ((const int32_t*)(opt0->data))[3]; + const int32_t d0 = ((const int32_t*)(opt0->data))[4]; + const int32_t d1 = ((const int32_t*)(opt0->data))[5]; + WSP_GGML_ASSERT(d0 == 1); // dilation not supported + WSP_GGML_ASSERT(d1 == 1); + WSP_GGML_ASSERT(p0 == 0); // padding not supported + WSP_GGML_ASSERT(p1 == 0); + + if (s0 == src0->ne[0] && s1 == src0->ne[1]) { + wsp_ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst); + } + else { + WSP_GGML_ASSERT(false); // only stride equal to kernel size is supported + }; +} + + +// wsp_ggml_compute_forward_flash_attn + +static void wsp_ggml_compute_forward_flash_attn_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * q, + const struct 
wsp_ggml_tensor * k, + const struct wsp_ggml_tensor * v, + const bool masked, + struct wsp_ggml_tensor * dst) { + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbk, k, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, nev, v, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbv, v, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = wsp_ggml_up(M, WSP_GGML_SOFT_MAX_UNROLL); + + WSP_GGML_ASSERT(ne0 == D); + WSP_GGML_ASSERT(ne1 == N); + WSP_GGML_ASSERT(P >= 0); + + WSP_GGML_ASSERT(nbq0 == sizeof(float)); + WSP_GGML_ASSERT(nbk0 == sizeof(float)); + WSP_GGML_ASSERT(nbv0 == sizeof(float)); + + WSP_GGML_ASSERT(neq0 == D); + WSP_GGML_ASSERT(nek0 == D); + WSP_GGML_ASSERT(nev1 == D); + + WSP_GGML_ASSERT(neq1 == N); + WSP_GGML_ASSERT(nek1 == N + P); + WSP_GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + WSP_GGML_ASSERT(nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb0 <= nb1); + WSP_GGML_ASSERT(nb1 <= nb2); + WSP_GGML_ASSERT(nb2 <= nb3); + + if (params->type == WSP_GGML_TASK_INIT) { + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using wsp_ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0f/sqrtf(D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + wsp_ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + wsp_ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int64_t i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + wsp_ggml_vec_max_f32(M, &max, S); + + wsp_ggml_float sum = 0.0; + { +#ifdef WSP_GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + wsp_ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[WSP_GGML_SOFT_MAX_UNROLL]; + wsp_ggml_float sump[WSP_GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += WSP_GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < WSP_GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = WSP_GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + sump[j] += (wsp_ggml_float)val; + SS[j] = val; + } + } + } + + for (int i = 0; i < WSP_GGML_SOFT_MAX_UNROLL; i++) 
{ + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + wsp_ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + for (int64_t ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + wsp_ggml_vec_dot_f32(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S); + } + } +} + +static void wsp_ggml_compute_forward_flash_attn_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * q, + const struct wsp_ggml_tensor * k, + const struct wsp_ggml_tensor * v, + const bool masked, + struct wsp_ggml_tensor * dst) { + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbk, k, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, nev, v, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbv, v, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = wsp_ggml_up(M, WSP_GGML_SOFT_MAX_UNROLL); + + WSP_GGML_ASSERT(ne0 == D); + WSP_GGML_ASSERT(ne1 == N); + WSP_GGML_ASSERT(P >= 0); + + WSP_GGML_ASSERT(nbq0 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nbk0 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nbv0 == sizeof(wsp_ggml_fp16_t)); + + WSP_GGML_ASSERT(neq0 == D); + WSP_GGML_ASSERT(nek0 == D); + WSP_GGML_ASSERT(nev1 == D); + + WSP_GGML_ASSERT(neq1 == N); + WSP_GGML_ASSERT(nek1 == N + P); + WSP_GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + WSP_GGML_ASSERT(nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb0 <= nb1); + WSP_GGML_ASSERT(nb1 <= nb2); + WSP_GGML_ASSERT(nb2 <= nb3); + + if (params->type == WSP_GGML_TASK_INIT) { + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using wsp_ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0f/sqrtf(D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + if (WSP_GGML_VEC_DOT_UNROLL > 2 || nek1 % WSP_GGML_VEC_DOT_UNROLL != 0) { + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + wsp_ggml_vec_dot_f16(neq0, + S + i1, + (wsp_ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (wsp_ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } else { + for (int64_t ic = 0; ic < nek1; ic += WSP_GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const 
int i1 = ik1; + + wsp_ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (wsp_ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } + + // scale + wsp_ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int64_t i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + wsp_ggml_vec_max_f32(M, &max, S); + + wsp_ggml_float sum = 0.0; + { +#ifdef WSP_GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + wsp_ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[WSP_GGML_SOFT_MAX_UNROLL]; + wsp_ggml_float sump[WSP_GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += WSP_GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < WSP_GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = WSP_GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + sump[j] += (wsp_ggml_float)val; + SS[j] = val; + } + } + } + + for (int i = 0; i < WSP_GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + wsp_ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + wsp_ggml_fp16_t * S16 = (wsp_ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); + + for (int64_t i = 0; i < M; i++) { + S16[i] = WSP_GGML_FP32_TO_FP16(S[i]); + } + + if (WSP_GGML_VEC_DOT_UNROLL == 1 || (nev1 % WSP_GGML_VEC_DOT_UNROLL != 0)) { + for (int64_t ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + wsp_ggml_vec_dot_f16(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (wsp_ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S16); + } + } else { + for (int64_t ic = 0; ic < nev1; ic += WSP_GGML_VEC_DOT_UNROLL) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + wsp_ggml_vec_dot_f16_unroll(nek1, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S16); + } + } + } +} + +static void wsp_ggml_compute_forward_flash_attn( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * q, + const struct wsp_ggml_tensor * k, + const struct wsp_ggml_tensor * v, + const bool masked, + struct wsp_ggml_tensor * dst) { + switch (q->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst); + } break; + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_flash_ff + +static void wsp_ggml_compute_forward_flash_ff_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * a, // F16 + const struct wsp_ggml_tensor * b0, // F16 fc_w + const struct wsp_ggml_tensor * b1, // F32 fc_b + const struct wsp_ggml_tensor * c0, // F16 proj_w + const struct wsp_ggml_tensor * c1, // F32 proj_b + struct wsp_ggml_tensor * dst) { + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_LOCALS(int64_t, nea, a, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nba, a, nb); + 
WSP_GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = nea0; + //const int64_t N = nea1; + const int64_t M = neb01; + + WSP_GGML_ASSERT(ne0 == nea0); + WSP_GGML_ASSERT(ne1 == nea1); + WSP_GGML_ASSERT(ne2 == nea2); + + WSP_GGML_ASSERT(nba0 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nbb00 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nbb10 == sizeof(float)); + WSP_GGML_ASSERT(nbc00 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nbc10 == sizeof(float)); + + WSP_GGML_ASSERT(neb00 == D); + WSP_GGML_ASSERT(neb01 == M); + WSP_GGML_ASSERT(neb10 == M); + WSP_GGML_ASSERT(neb11 == 1); + + WSP_GGML_ASSERT(nec00 == M); + WSP_GGML_ASSERT(nec01 == D); + WSP_GGML_ASSERT(nec10 == D); + WSP_GGML_ASSERT(nec11 == 1); + + // dst cannot be transposed or permuted + WSP_GGML_ASSERT(nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb0 <= nb1); + WSP_GGML_ASSERT(nb1 <= nb2); + WSP_GGML_ASSERT(nb2 <= nb3); + + if (params->type == WSP_GGML_TASK_INIT) { + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // parallelize by a rows using wsp_ggml_vec_dot_f32 + + // total rows in a + const int nr = nea1*nea2*nea3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // a indices + const int ia3 = ir/(nea2*nea1); + const int ia2 = (ir - ia3*nea2*nea1)/nea1; + const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1); + + float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); + + for (int64_t ic = 0; ic < neb01; ++ic) { + // b0 indices + const int ib03 = ia3; + const int ib02 = ia2; + const int ib01 = ic; + + // S indices + const int i1 = ib01; + + wsp_ggml_vec_dot_f16(nea0, + S + i1, + (wsp_ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), + (wsp_ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3))); + } + + wsp_ggml_vec_add_f32(neb01, S, S, (float *) b1->data); + //wsp_ggml_vec_gelu_f32(neb01, S, S); + + wsp_ggml_fp16_t * S16 = (wsp_ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); + + for (int64_t i = 0; i < M; i++) { + S16[i] = WSP_GGML_FP32_TO_FP16(S[i]); + } + + wsp_ggml_vec_gelu_f16(neb01, S16, S16); + + { + // dst indices + const int i1 = ia1; + const int i2 = ia2; + const int i3 = ia3; + + for (int64_t ic = 0; ic < nec01; ++ic) { + + wsp_ggml_vec_dot_f16(neb01, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (wsp_ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), + S16); + } + + wsp_ggml_vec_add_f32(nec01, + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) c1->data); + } + } +} + +static void wsp_ggml_compute_forward_flash_ff( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * a, + const struct wsp_ggml_tensor * b0, + const struct wsp_ggml_tensor * b1, + const struct wsp_ggml_tensor * c0, + const struct 
wsp_ggml_tensor * c1, + struct wsp_ggml_tensor * dst) { + switch (b0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst); + } break; + case WSP_GGML_TYPE_F32: + { + WSP_GGML_ASSERT(false); // TODO + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_flash_attn_back + +static void wsp_ggml_compute_forward_flash_attn_back_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * q, + const struct wsp_ggml_tensor * k, + const struct wsp_ggml_tensor * v, + const struct wsp_ggml_tensor * d, + const bool masked, + struct wsp_ggml_tensor * dst) { + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbk, k, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, nev, v, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbv, v, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, ned, d, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nbd, d, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = wsp_ggml_up(M, WSP_GGML_SOFT_MAX_UNROLL); + const int mxDM = MAX(D, Mup); + + // WSP_GGML_ASSERT(ne0 == D); + // WSP_GGML_ASSERT(ne1 == N); + WSP_GGML_ASSERT(P >= 0); + + WSP_GGML_ASSERT(nbq0 == sizeof(float)); + WSP_GGML_ASSERT(nbk0 == sizeof(float)); + WSP_GGML_ASSERT(nbv0 == sizeof(float)); + + WSP_GGML_ASSERT(neq0 == D); + WSP_GGML_ASSERT(nek0 == D); + WSP_GGML_ASSERT(nev1 == D); + WSP_GGML_ASSERT(ned0 == D); + + WSP_GGML_ASSERT(neq1 == N); + WSP_GGML_ASSERT(nek1 == N + P); + WSP_GGML_ASSERT(nev1 == D); + WSP_GGML_ASSERT(ned1 == N); + + // dst cannot be transposed or permuted + WSP_GGML_ASSERT(nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb0 <= nb1); + WSP_GGML_ASSERT(nb1 <= nb2); + WSP_GGML_ASSERT(nb2 <= nb3); + + if (params->type == WSP_GGML_TASK_INIT) { + if (ith == 0) { + memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); + } + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using wsp_ggml_vec_dot_f32 + + // total rows in q + const int nr = neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0f/sqrtf(D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2); + const int iq2 = ir - iq3*neq2; + for ( int iq1 = 0; iq1 < neq1; ++iq1) { + + + // not sure about CACHE_LINE_SIZE_F32.. + // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? 
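+            // Per-thread scratch layout as currently written (illustrative numbers only;
+            // assuming CACHE_LINE_SIZE_F32 == 16, i.e. 64-byte cache lines, and e.g. mxDM == 1024):
+            //   each thread owns 2*(mxDM + CACHE_LINE_SIZE_F32) = 2080 floats of wdata
+            //   S  = wdata + ith*2080 + 0*1040   // scores, later reused for gradient vectors
+            //   SM = wdata + ith*2080 + 1*1040   // softmax(S), kept for the backward step
+            // i.e. each vector gets an mxDM-float slot followed by CACHE_LINE_SIZE_F32 floats
+            // of slack (cache-line padding between S, SM and the neighbouring threads' regions).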
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); + float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + wsp_ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + wsp_ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int64_t i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + wsp_ggml_vec_max_f32(M, &max, S); + + wsp_ggml_float sum = 0.0; + { +#ifdef WSP_GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(SM, 1, &max, SM, 1, Mup); + vvexpf(SM, SM, &Mup); + wsp_ggml_vec_sum_f32(Mup, &sum, SM); +#else + uint16_t scvt[WSP_GGML_SOFT_MAX_UNROLL]; + wsp_ggml_float sump[WSP_GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += WSP_GGML_SOFT_MAX_UNROLL) { + float * SR = S + i; + float * SW = SM + i; + + for (int j = 0; j < WSP_GGML_SOFT_MAX_UNROLL; ++j) { + if (SR[j] == -INFINITY) { + SW[j] = 0.0f; + } else { + wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(SR[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = WSP_GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + sump[j] += (wsp_ggml_float)val; + SW[j] = val; + } + } + } + + for (int i = 0; i < WSP_GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + wsp_ggml_vec_scale_f32(M, SM, sum); + + } + + // step-by-step explanation + { + // forward-process shape grads from backward process + // parallel_for iq2,iq3: + // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur] + // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] + // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur] + // for iq1: + // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur + // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur + // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 + // S0 = -Inf [D,1,1,1] + // ~S1[i] = dot(kcur[:D,i], qcur) + // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale + // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) + // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur + // ~S5[i] = dot(vcur[:,i], S4) + // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3] + // ~dst[i,iq1,iq2,iq3] = S5[i] ^ + // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3] + // dst backward-/ grad[dst] = d + // + // output gradients with their dependencies: + // + // grad[kcur] = grad[S1].T @ qcur + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S4] = grad[S5] @ vcur + // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur + // grad[qcur] = grad[S1] @ kcur + // grad[vcur] = grad[S5].T @ S4 + // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4 + // + // in post-order: + // + // S1 = qcur @ kcur.T + // S2 = S1 * scale + // S3 = diag_mask_inf(S2, P) + // S4 = softmax(S3) + // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + 
// grad[qcur] = grad[S1] @ kcur + // grad[kcur] = grad[S1].T @ qcur + // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4 + // + // using less variables (SM=S4): + // + // S = diag_mask_inf(qcur @ kcur.T * scale, P) + // SM = softmax(S) + // S = d[:D,iq1,iq2,iq3] @ vcur + // dot_SM_gradSM = dot(SM, S) + // S = SM * (S - dot(SM, S)) + // S = diag_mask_zero(S, P) * scale + // + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[k][:D,:M,iq2,iq3] += S.T @ qcur + // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM + } + + // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur + // S = d[:D,iq1,iq2,iq3] @ vcur + // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3] + wsp_ggml_vec_set_f32(M, S, 0); + for (int64_t ic = 0; ic < D; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + wsp_ggml_vec_mad_f32(M, + S, + (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3))); + } + + // S = SM * (S - dot(SM, S)) + float dot_SM_gradSM = 0; + wsp_ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S); + wsp_ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); + wsp_ggml_vec_mul_f32 (M, S, S, SM); + + // S = diag_mask_zero(S, P) * scale + if (masked) { + // for (int64_t i = P + iq1 + 1; i < M; i++) { + // S[i] = 0; + // } + for (int64_t i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = 0; + } + } + } + wsp_ggml_vec_scale_f32(M, S, scale); + + void * grad_q = (char *) dst->data; + void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3; + void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3; + + const size_t nbgq1 = nb0*neq0; + const size_t nbgq2 = nb0*neq0*neq1; + const size_t nbgq3 = nb0*neq0*neq1*neq2; + + const size_t nbgk1 = nb0*nek0; + const size_t nbgk2 = nb0*nek0*nek1; + const size_t nbgk3 = nb0*nek0*nek1*neq2; + + const size_t nbgv1 = nb0*nev0; + const size_t nbgv2 = nb0*nev0*nev1; + const size_t nbgv3 = nb0*nev0*nev1*neq2; + + // S shape [M,1] + // SM shape [M,1] + // kcur shape [D,M] + // qcur shape [D,1] + // vcur shape [M,D] + // + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] + // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic] + // + //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T) + //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T) + for (int64_t ic = 0; ic < M; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + wsp_ggml_vec_mad_f32(D, + (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)), + (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)), + S[ic]); + } + + // grad[k][:D,:M,iq2,iq3] += S.T @ qcur + // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] + // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] + for (int64_t ic = 0; ic < M; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // wsp_ggml_vec_set_f32(D, + // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)), + // 0); + wsp_ggml_vec_mad_f32(D, + (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)), + (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)), + S[ic]); + } + + // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM + // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M] + // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M] + for (int64_t ic = 0; ic < D; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // wsp_ggml_vec_set_f32(M, + // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + 
i3*nbgv3)), + // 0); + wsp_ggml_vec_mad_f32(M, + (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)), + SM, + *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3))); + } + } + } +} + +static void wsp_ggml_compute_forward_flash_attn_back( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * q, + const struct wsp_ggml_tensor * k, + const struct wsp_ggml_tensor * v, + const struct wsp_ggml_tensor * d, + const bool masked, + struct wsp_ggml_tensor * dst) { + switch (q->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_win_part + +static void wsp_ggml_compute_forward_win_part_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + + const int32_t nep0 = ((const int32_t *)(opt0->data))[0]; + const int32_t nep1 = ((const int32_t *)(opt0->data))[1]; + const int32_t w = ((const int32_t *)(opt0->data))[2]; + + assert(ne00 == ne0); + assert(ne3 == nep0*nep1); + + // TODO: optimize / multi-thread + for (int py = 0; py < nep1; ++py) { + for (int px = 0; px < nep0; ++px) { + const int64_t i3 = py*nep0 + px; + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i02 = py*w + i2; + const int64_t i01 = px*w + i1; + const int64_t i00 = i0; + + const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0; + const int64_t j = i02*ne01*ne00 + i01*ne00 + i00; + + if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { + ((float *) dst->data)[i] = 0.0f; + } else { + ((float *) dst->data)[i] = ((float *) src0->data)[j]; + } + } + } + } + } + } +} + +static void wsp_ggml_compute_forward_win_part( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_win_part_f32(params, src0, opt0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_win_unpart + +static void wsp_ggml_compute_forward_win_unpart_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + + const int32_t w = ((const int32_t *)(opt0->data))[0]; + + // padding + const int px = (w - ne1%w)%w; + //const int py = (w - ne2%w)%w; + + const int npx = (px + ne1)/w; + //const int npy = (py + ne2)/w; + + assert(ne0 == ne00); + + // TODO: optimize / multi-thread + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int ip2 = i2/w; + const int ip1 = i1/w; + + const int64_t i02 = i2%w; + const int64_t i01 = i1%w; + const int64_t i00 = i0; + + const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00; + const int64_t j = 
i2*ne1*ne0 + i1*ne0 + i0; + + ((float *) dst->data)[j] = ((float *) src0->data)[i]; + } + } + } +} + +static void wsp_ggml_compute_forward_win_unpart( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_win_unpart_f32(params, src0, opt0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_map_unary + +static void wsp_ggml_compute_forward_map_unary_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst, + const wsp_ggml_unary_op_f32_t fun) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + fun(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + + +static void wsp_ggml_compute_forward_map_unary( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst, + const wsp_ggml_unary_op_f32_t fun) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_map_unary_f32(params, src0, dst, fun); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_map_binary + +static void wsp_ggml_compute_forward_map_binary_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst, + const wsp_ggml_binary_op_f32_t fun) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + assert(src1->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + fun(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + + +static void wsp_ggml_compute_forward_map_binary( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst, + const wsp_ggml_binary_op_f32_t fun) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_map_custom1 + +static void wsp_ggml_compute_forward_map_custom1_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * dst, + const wsp_ggml_custom1_op_f32_t fun) { + assert(params->ith == 0); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + fun(dst, a); +} + + +static void wsp_ggml_compute_forward_map_custom1( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * a, + struct 
wsp_ggml_tensor * dst, + const wsp_ggml_custom1_op_f32_t fun) { + switch (a->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_map_custom1_f32(params, a, dst, fun); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_map_custom2 + +static void wsp_ggml_compute_forward_map_custom2_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * a, + const struct wsp_ggml_tensor * b, + struct wsp_ggml_tensor * dst, + const wsp_ggml_custom2_op_f32_t fun) { + assert(params->ith == 0); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + fun(dst, a, b); +} + + +static void wsp_ggml_compute_forward_map_custom2( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * a, + const struct wsp_ggml_tensor * b, + struct wsp_ggml_tensor * dst, + const wsp_ggml_custom2_op_f32_t fun) { + switch (a->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_map_custom2_f32(params, a, b, dst, fun); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_map_custom3 + +static void wsp_ggml_compute_forward_map_custom3_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * a, + const struct wsp_ggml_tensor * b, + const struct wsp_ggml_tensor * c, + struct wsp_ggml_tensor * dst, + const wsp_ggml_custom3_op_f32_t fun) { + assert(params->ith == 0); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + fun(dst, a, b, c); +} + + +static void wsp_ggml_compute_forward_map_custom3( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * a, + const struct wsp_ggml_tensor * b, + const struct wsp_ggml_tensor * c, + struct wsp_ggml_tensor * dst, + const wsp_ggml_custom3_op_f32_t fun) { + switch (a->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_map_custom3_f32(params, a, b, c, dst, fun); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_cross_entropy_loss + +static void wsp_ggml_compute_forward_cross_entropy_loss_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1)); + WSP_GGML_ASSERT(wsp_ggml_is_scalar(dst)); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1)); + + const int ith = params->ith; + const int nth = params->nth; + + float * sums = (float *) params->wdata; + + // TODO: handle transposed/permuted matrices + const int nc = src0->ne[0]; + const int nr = wsp_ggml_nrows(src0); + + if (params->type == WSP_GGML_TASK_INIT) { + if (ith == 0) { + memset(sums, 0, sizeof(float) * (nth + nth * nc)); + } + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + if (ith == 0) { + float * dp = (float *) dst->data; + wsp_ggml_vec_sum_f32(nth, dp, sums); + dp[0] *= -1.0f; + } + return; + } + + const double eps = 1e-9; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); + float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); + float * st = (float *) params->wdata + nth + ith*nc; + +#ifndef NDEBUG + for (int i = 
0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + // soft_max + wsp_ggml_float sum = 0.0; + { + float max = -INFINITY; + wsp_ggml_vec_max_f32(nc, &max, s0); + + uint16_t scvt; + for (int i = 0; i < nc; i++) { + if (s0[i] == -INFINITY) { + st[i] = 0.0f; + } else { + // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max); + wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(s0[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = WSP_GGML_FP16_TO_FP32(table_exp_f16[scvt]); + sum += (wsp_ggml_float)val; + st[i] = val; + } + } + + assert(sum > 0.0); + // sum = 1.0/sum; + } + // avoid log(0) by rescaling from [0..1] to [eps..1] + sum = (1.0 - eps) / sum; + wsp_ggml_vec_scale_f32(nc, st, sum); + wsp_ggml_vec_add1_f32(nc, st, st, eps); + wsp_ggml_vec_log_f32(nc, st, st); + wsp_ggml_vec_mul_f32(nc, st, st, s1); + + wsp_ggml_vec_sum_f32(nc, sums + ith, st); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(st[i])); + assert(!isinf(st[i])); + } +#endif + } + +} + +static void wsp_ggml_compute_forward_cross_entropy_loss( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +// wsp_ggml_compute_forward_cross_entropy_loss_back + +static void wsp_ggml_compute_forward_cross_entropy_loss_back_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(opt0)); + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst)); + + const int64_t ith = params->ith; + const int64_t nth = params->nth; + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const float eps = 1e-9f; + + // TODO: handle transposed/permuted matrices + const int64_t nc = src0->ne[0]; + const int64_t nr = wsp_ggml_nrows(src0); + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + float * d = (float *) opt0->data; + + for (int64_t i1 = ir0; i1 < ir1; i1++) { + float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); + float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); + float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); + float * sm = (float *) params->wdata + ith*nc; + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + // step by step explanation: + { + //float * sums = (float *) params->wdata; + + // forward pass with annotated gradients from backward pass + // (built by going in reverse operation order, adding to gradients of current operation args) + // st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum + // from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1])) + // wsp_ggml_vec_scale_f32(nc, st, sum); // st1 = st0*/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 
- eps) + // wsp_ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3] + // wsp_ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3 + // wsp_ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1 + // wsp_ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]] + // wsp_ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel] + + // substitute into grad[st1], because we can reuse softmax_back from this point on + // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps)) + // postorder: + // grad[st1] := softmax(s0) + // grad[st1] := grad[st1]*(1.0 - eps) + // grad[st1] := grad[st1] + eps + // grad[st1] := s1 / grad[st1] + // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel] + + // src0 gradients by going through softmax_back + // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1])) + // from softmax_back: + // dxk = yk * (dyk - dot(y, dy)) + // dot_y_dy := dot(y, dy) + // dx := dy + // dx := dx - dot_y_dy + // dx := dx * y + // postorder: + // dot_st1_dst1 := dot(st1, grad[st1]) + // grad[s0] := grad[st1] + // grad[s0] := grad[s0] - dot_st1_dst1 + // grad[s0] := grad[s0] * st1 + + // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1] + // sm := softmax(s0) + // grad[s0] := sm*(1.0 - eps) + // grad[s0] := grad[s0] + eps + // grad[s0] := s1 / grad[s0] + // grad[s0] := grad[s0]*(1.0-eps)*-grad[cel] + // dot_st1_dst1 := dot(sm, grad[s0]) + // grad[s0] := grad[s0] - dot_st1_dst1 + // grad[s0] := grad[s0] * sm + } + + // soft_max + wsp_ggml_float sum = 0.0; + { + float max = -INFINITY; + wsp_ggml_vec_max_f32(nc, &max, s0); + + uint16_t scvt; + for (int i = 0; i < nc; i++) { + if (s0[i] == -INFINITY) { + sm[i] = 0.0f; + } else { + // const float val = (s0[i] == -INFINITY) ? 
0.0 : exp(s0[i] - max); + wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(s0[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = WSP_GGML_FP16_TO_FP32(table_exp_f16[scvt]); + sum += (wsp_ggml_float)val; + sm[i] = val; + } + } + + assert(sum > 0.0); + sum = 1.0/sum; + } + + float dot_st1_dst1 = 0; + wsp_ggml_vec_scale_f32(nc, sm, sum); + wsp_ggml_vec_cpy_f32 (nc, ds0, sm); + wsp_ggml_vec_scale_f32(nc, ds0, (1.0f - eps)); + wsp_ggml_vec_add1_f32 (nc, ds0, ds0, eps); + wsp_ggml_vec_div_f32 (nc, ds0, s1, ds0); + wsp_ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]); + wsp_ggml_vec_dot_f32 (nc, &dot_st1_dst1, sm, ds0); + wsp_ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1); + wsp_ggml_vec_mul_f32 (nc, ds0, ds0, sm); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(sm[i])); + assert(!isinf(sm[i])); + assert(!isnan(ds0[i])); + assert(!isinf(ds0[i])); + } +#endif + } +} + +static void wsp_ggml_compute_forward_cross_entropy_loss_back( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + const struct wsp_ggml_tensor * opt0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + + +///////////////////////////////// + +static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * tensor) { + WSP_GGML_ASSERT(params); + +#ifdef WSP_GGML_USE_CUBLAS + bool skip_cpu = wsp_ggml_cuda_compute_forward(params, tensor); + if (skip_cpu) { + return; + } + WSP_GGML_ASSERT(tensor->src0 == NULL || tensor->src0->backend == WSP_GGML_BACKEND_CPU); + WSP_GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == WSP_GGML_BACKEND_CPU); +#endif // WSP_GGML_USE_CUBLAS + + switch (tensor->op) { + case WSP_GGML_OP_DUP: + { + wsp_ggml_compute_forward_dup(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_ADD: + { + wsp_ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_ADD1: + { + wsp_ggml_compute_forward_add1(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_ACC: + { + wsp_ggml_compute_forward_acc(params, tensor->src0, tensor->src1, tensor->opt[0], tensor); + } break; + case WSP_GGML_OP_SUB: + { + wsp_ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_MUL: + { + wsp_ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_DIV: + { + wsp_ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_SQR: + { + wsp_ggml_compute_forward_sqr(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_SQRT: + { + wsp_ggml_compute_forward_sqrt(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_LOG: + { + wsp_ggml_compute_forward_log(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_SUM: + { + wsp_ggml_compute_forward_sum(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_SUM_ROWS: + { + wsp_ggml_compute_forward_sum_rows(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_MEAN: + { + wsp_ggml_compute_forward_mean(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_ARGMAX: + { + wsp_ggml_compute_forward_argmax(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_REPEAT: + { + wsp_ggml_compute_forward_repeat(params, tensor->src0, tensor); + } 
break; + case WSP_GGML_OP_REPEAT_BACK: + { + wsp_ggml_compute_forward_repeat_back(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_ABS: + { + wsp_ggml_compute_forward_abs(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_SGN: + { + wsp_ggml_compute_forward_sgn(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_NEG: + { + wsp_ggml_compute_forward_neg(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_STEP: + { + wsp_ggml_compute_forward_step(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_TANH: + { + wsp_ggml_compute_forward_tanh(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_ELU: + { + wsp_ggml_compute_forward_elu(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_RELU: + { + wsp_ggml_compute_forward_relu(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_GELU: + { + wsp_ggml_compute_forward_gelu(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_GELU_QUICK: + { + wsp_ggml_compute_forward_gelu_quick(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_SILU: + { + wsp_ggml_compute_forward_silu(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_SILU_BACK: + { + wsp_ggml_compute_forward_silu_back(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_NORM: + { + wsp_ggml_compute_forward_norm(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_RMS_NORM: + { + wsp_ggml_compute_forward_rms_norm(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_RMS_NORM_BACK: + { + wsp_ggml_compute_forward_rms_norm_back(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_MUL_MAT: + { + wsp_ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_OUT_PROD: + { + wsp_ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_SCALE: + { + wsp_ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_SET: + { + wsp_ggml_compute_forward_set(params, tensor->src0, tensor->src1, tensor->opt[0], tensor); + } break; + case WSP_GGML_OP_CPY: + { + wsp_ggml_compute_forward_cpy(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_CONT: + { + wsp_ggml_compute_forward_cont(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_RESHAPE: + { + wsp_ggml_compute_forward_reshape(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_VIEW: + { + wsp_ggml_compute_forward_view(params, tensor->src0); + } break; + case WSP_GGML_OP_PERMUTE: + { + wsp_ggml_compute_forward_permute(params, tensor->src0); + } break; + case WSP_GGML_OP_TRANSPOSE: + { + wsp_ggml_compute_forward_transpose(params, tensor->src0); + } break; + case WSP_GGML_OP_GET_ROWS: + { + wsp_ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_GET_ROWS_BACK: + { + wsp_ggml_compute_forward_get_rows_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor); + } break; + case WSP_GGML_OP_DIAG: + { + wsp_ggml_compute_forward_diag(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_DIAG_MASK_INF: + { + wsp_ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_DIAG_MASK_ZERO: + { + wsp_ggml_compute_forward_diag_mask_zero(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_SOFT_MAX: + { + wsp_ggml_compute_forward_soft_max(params, tensor->src0, tensor); + } break; + case WSP_GGML_OP_SOFT_MAX_BACK: + { + 
wsp_ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_ROPE: + { + wsp_ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_ROPE_BACK: + { + wsp_ggml_compute_forward_rope_back(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_ALIBI: + { + wsp_ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_CLAMP: + { + wsp_ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor); + } break; + case WSP_GGML_OP_CONV_1D: + { + wsp_ggml_compute_forward_conv_1d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor); + } break; + case WSP_GGML_OP_CONV_2D: + { + wsp_ggml_compute_forward_conv_2d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor); + } break; + case WSP_GGML_OP_FLASH_ATTN: + { + const int32_t t = wsp_ggml_get_i32_1d(tensor->opt[1], 0); + WSP_GGML_ASSERT(t == 0 || t == 1); + const bool masked = t != 0; + wsp_ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor); + } break; + case WSP_GGML_OP_FLASH_FF: + { + wsp_ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor); + } break; + case WSP_GGML_OP_FLASH_ATTN_BACK: + { + int32_t t = wsp_ggml_get_i32_1d(tensor->opt[2], 0); + WSP_GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + wsp_ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor); + } break; + case WSP_GGML_OP_WIN_PART: + { + wsp_ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor); + } break; + case WSP_GGML_OP_WIN_UNPART: + { + wsp_ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor); + } break; + case WSP_GGML_OP_MAP_UNARY: + { + const wsp_ggml_unary_op_f32_t fun = *((wsp_ggml_unary_op_f32_t *)tensor->opt[0]->data); + wsp_ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun); + } + break; + case WSP_GGML_OP_MAP_BINARY: + { + const wsp_ggml_binary_op_f32_t fun = *((wsp_ggml_binary_op_f32_t *)tensor->opt[0]->data); + wsp_ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun); + } + break; + case WSP_GGML_OP_MAP_CUSTOM1: + { + const wsp_ggml_custom1_op_f32_t fun = *((wsp_ggml_custom1_op_f32_t *)tensor->opt[0]->data); + wsp_ggml_compute_forward_map_custom1(params, tensor->src0, tensor, fun); + } + break; + case WSP_GGML_OP_MAP_CUSTOM2: + { + const wsp_ggml_custom2_op_f32_t fun = *((wsp_ggml_custom2_op_f32_t *)tensor->opt[0]->data); + wsp_ggml_compute_forward_map_custom2(params, tensor->src0, tensor->src1, tensor, fun); + } + break; + case WSP_GGML_OP_MAP_CUSTOM3: + { + const wsp_ggml_custom3_op_f32_t fun = *((wsp_ggml_custom3_op_f32_t *)tensor->opt[0]->data); + wsp_ggml_compute_forward_map_custom3(params, tensor->src0, tensor->src1, tensor->opt[1], tensor, fun); + } + break; + case WSP_GGML_OP_CROSS_ENTROPY_LOSS: + { + wsp_ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor); + } + break; + case WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + wsp_ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor); + } + break; + case WSP_GGML_OP_NONE: + { + // nop + } break; + case WSP_GGML_OP_COUNT: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +static void 
wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor, bool inplace) { + struct wsp_ggml_tensor * src0 = tensor->src0; + struct wsp_ggml_tensor * src1 = tensor->src1; + + switch (tensor->op) { + case WSP_GGML_OP_DUP: + { + if (src0->grad) { + src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + } break; + case WSP_GGML_OP_ADD: + { + if (src0->grad) { + src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + src1->grad = wsp_ggml_add_impl(ctx, src1->grad, tensor->grad, inplace); + } + } break; + case WSP_GGML_OP_ADD1: + { + if (src0->grad) { + src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + src1->grad = wsp_ggml_add_impl(ctx, + src1->grad, + wsp_ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean + inplace); + } + } break; + case WSP_GGML_OP_ACC: + { + if (src0->grad) { + src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + WSP_GGML_ASSERT(wsp_ggml_nelements(tensor->opt[0]) == 5); + WSP_GGML_ASSERT(tensor->opt[0]->type == WSP_GGML_TYPE_I32); + const size_t nb1 = (( int32_t * ) tensor->opt[0]->data)[0]; + const size_t nb2 = (( int32_t * ) tensor->opt[0]->data)[1]; + const size_t nb3 = (( int32_t * ) tensor->opt[0]->data)[2]; + const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3]; + + struct wsp_ggml_tensor * tensor_grad_view = wsp_ggml_view_4d(ctx, + tensor->grad, + src1->grad->ne[0], + src1->grad->ne[1], + src1->grad->ne[2], + src1->grad->ne[3], + nb1, nb2, nb3, offset); + + src1->grad = + wsp_ggml_add_impl(ctx, + src1->grad, + wsp_ggml_reshape(ctx, + wsp_ggml_cont(ctx, tensor_grad_view), + src1->grad), + inplace); + } + } break; + case WSP_GGML_OP_SUB: + { + if (src0->grad) { + src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + src1->grad = wsp_ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace); + } + } break; + case WSP_GGML_OP_MUL: + { + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_mul(ctx, src1, tensor->grad), + inplace); + } + if (src1->grad) { + src1->grad = + wsp_ggml_add_impl(ctx, + src1->grad, + wsp_ggml_mul(ctx, src0, tensor->grad), + inplace); + } + } break; + case WSP_GGML_OP_DIV: + { + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_div(ctx, tensor->grad, src1), + inplace); + } + if (src1->grad) { + src1->grad = + wsp_ggml_sub_impl(ctx, + src1->grad, + wsp_ggml_mul(ctx, + tensor->grad, + wsp_ggml_div(ctx, tensor, src1)), + inplace); + } + } break; + case WSP_GGML_OP_SQR: + { + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_scale(ctx, + wsp_ggml_mul(ctx, src0, tensor->grad), + wsp_ggml_new_f32(ctx, 2.0f)), + inplace); + } + } break; + case WSP_GGML_OP_SQRT: + { + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_scale(ctx, + wsp_ggml_div(ctx, + tensor->grad, + tensor), + wsp_ggml_new_f32(ctx, 0.5f)), + inplace); + } + } break; + case WSP_GGML_OP_LOG: + { + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_div(ctx, + tensor->grad, + src0), + inplace); + } + } break; + case WSP_GGML_OP_SUM: + { + if (src0->grad) { + src0->grad = + wsp_ggml_add1_impl(ctx, + src0->grad, + tensor->grad, + inplace); + } + } break; + case WSP_GGML_OP_SUM_ROWS: + { + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, + src0->grad, + 
wsp_ggml_repeat(ctx, + tensor->grad, + src0->grad), + inplace); + } + } break; + case WSP_GGML_OP_MEAN: + case WSP_GGML_OP_ARGMAX: + { + WSP_GGML_ASSERT(false); // TODO: implement + } break; + case WSP_GGML_OP_REPEAT: + { + // necessary for llama + if (src0->grad) { + src0->grad = wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_repeat_back(ctx, tensor->grad, src0->grad), + inplace); + } + } break; + case WSP_GGML_OP_REPEAT_BACK: + { + if (src0->grad) { + // TODO: test this + src0->grad = wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_repeat(ctx, tensor->grad, src0->grad), + inplace); + } + } break; + case WSP_GGML_OP_ABS: + { + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_mul(ctx, + wsp_ggml_sgn(ctx, src0), + tensor->grad), + inplace); + } + } break; + case WSP_GGML_OP_SGN: + { + if (src0->grad) { + // noop + } + } break; + case WSP_GGML_OP_NEG: + { + if (src0->grad) { + src0->grad = wsp_ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace); + } + } break; + case WSP_GGML_OP_STEP: + { + if (src0->grad) { + // noop + } + } break; + case WSP_GGML_OP_TANH: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_ELU: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_RELU: + { + if (src0->grad) { + src0->grad = wsp_ggml_sub_impl(ctx, + src0->grad, + wsp_ggml_mul(ctx, + wsp_ggml_step(ctx, src0), + tensor->grad), + inplace); + } + } break; + case WSP_GGML_OP_GELU: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_GELU_QUICK: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_SILU: + { + // necessary for llama + if (src0->grad) { + src0->grad = wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_silu_back(ctx, src0, tensor->grad), + inplace); + } + } break; + case WSP_GGML_OP_SILU_BACK: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_NORM: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_RMS_NORM: + { + // necessary for llama + if (src0->grad) { + src0->grad = wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_rms_norm_back(ctx, src0, tensor->grad), + inplace); + } + } break; + case WSP_GGML_OP_RMS_NORM_BACK: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_MUL_MAT: + { + // https://cs231n.github.io/optimization-2/#staged + // # forward pass + // s0 = np.random.randn(5, 10) + // s1 = np.random.randn(10, 3) + // t = s0.dot(s1) + + // # now suppose we had the gradient on t from above in the circuit + // dt = np.random.randn(*t.shape) # same shape as t + // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix + // ds1 = t.T.dot(dt) + + // tensor.shape [m,p] + // src0.shape [n,m] + // src1.shape [n,p] + + // necessary for llama + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_out_prod(ctx, // [n,m] + src1, // [n,p] + tensor->grad), // [m,p] + inplace); + } + if (src1->grad) { + src1->grad = + wsp_ggml_add_impl(ctx, + src1->grad, + // wsp_ggml_mul_mat(ctx, // [n,p] + // wsp_ggml_cont(ctx, // [m,n] + // wsp_ggml_transpose(ctx, src0)), // [m,n] + // tensor->grad), // [m,p] + + // // when src0 is bigger than tensor->grad (this is mostly the case in llama), + // // avoid transpose of src0, rather transpose smaller tensor->grad + // // and then use wsp_ggml_out_prod + wsp_ggml_out_prod(ctx, // [n,p] + src0, // [n,m] + wsp_ggml_transpose(ctx, // [p,m] + tensor->grad)), // [m,p] + inplace); + } + } 
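Aside: the staged-backprop comment in the WSP_GGML_OP_MUL_MAT case above relies on the identities dA = dC * B^T and dB = A^T * dC for C = A * B; the wsp_ggml_out_prod calls express the same identities in ggml's own dimension ordering. A minimal standalone sketch, with made-up matrix sizes, that spells the identities out with explicit loops (illustration only, not part of this patch):

    /* Checks dA = dC * B^T and dB = A^T * dC for C = A * B on tiny dense matrices. */
    #include <stdio.h>

    int main(void) {
        /* A is 2x3, B is 3x2, so C = A*B is 2x2 */
        const float A[2][3]  = {{1, 2, 3}, {4, 5, 6}};
        const float B[3][2]  = {{1, 0}, {0, 1}, {1, 1}};
        const float dC[2][2] = {{1, 1}, {1, 1}};   /* upstream gradient on C */

        float dA[2][3] = {{0}}, dB[3][2] = {{0}};

        /* dA[i][k] = sum_j dC[i][j] * B[k][j]   (dC * B^T) */
        for (int i = 0; i < 2; ++i)
            for (int k = 0; k < 3; ++k)
                for (int j = 0; j < 2; ++j)
                    dA[i][k] += dC[i][j] * B[k][j];

        /* dB[k][j] = sum_i A[i][k] * dC[i][j]   (A^T * dC) */
        for (int k = 0; k < 3; ++k)
            for (int j = 0; j < 2; ++j)
                for (int i = 0; i < 2; ++i)
                    dB[k][j] += A[i][k] * dC[i][j];

        printf("dA[0][0] = %.1f, dB[0][0] = %.1f\n", dA[0][0], dB[0][0]);
        return 0;
    }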
break; + case WSP_GGML_OP_OUT_PROD: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_SCALE: + { + // necessary for llama + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_scale_impl(ctx, tensor->grad, src1, false), + inplace); + } + if (src1->grad) { + src1->grad = + wsp_ggml_add_impl(ctx, + src1->grad, + wsp_ggml_sum(ctx, wsp_ggml_mul_impl(ctx, tensor->grad, src0, false)), + inplace); + } + } break; + case WSP_GGML_OP_SET: + { + WSP_GGML_ASSERT(wsp_ggml_nelements(tensor->opt[0]) == 5); + WSP_GGML_ASSERT(tensor->opt[0]->type == WSP_GGML_TYPE_I32); + const size_t nb1 = (( int32_t * ) tensor->opt[0]->data)[0]; + const size_t nb2 = (( int32_t * ) tensor->opt[0]->data)[1]; + const size_t nb3 = (( int32_t * ) tensor->opt[0]->data)[2]; + const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3]; + + struct wsp_ggml_tensor * tensor_grad_view = NULL; + + if (src0->grad || src1->grad) { + WSP_GGML_ASSERT(src0->type == tensor->type); + WSP_GGML_ASSERT(tensor->grad->type == tensor->type); + WSP_GGML_ASSERT(tensor->grad->type == src1->grad->type); + + tensor_grad_view = wsp_ggml_view_4d(ctx, + tensor->grad, + src1->grad->ne[0], + src1->grad->ne[1], + src1->grad->ne[2], + src1->grad->ne[3], + nb1, nb2, nb3, offset); + } + + if (src0->grad) { + src0->grad = wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_acc_impl(ctx, + tensor->grad, + wsp_ggml_neg(ctx, tensor_grad_view), + nb1, nb2, nb3, offset, false), + inplace); + } + + if (src1->grad) { + src1->grad = + wsp_ggml_add_impl(ctx, + src1->grad, + wsp_ggml_reshape(ctx, + wsp_ggml_cont(ctx, tensor_grad_view), + src1->grad), + inplace); + } + } break; + case WSP_GGML_OP_CPY: + { + // necessary for llama + // cpy overwrites value of src1 by src0 and returns view(src1) + // the overwriting is mathematically equivalent to: + // tensor = src0 * 1 + src1 * 0 + if (src0->grad) { + // dsrc0 = dtensor * 1 + src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + // dsrc1 = dtensor * 0 -> noop + } + } break; + case WSP_GGML_OP_CONT: + { + // same as cpy + if (src0->grad) { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0->grad)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(tensor->grad)); + src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + } break; + case WSP_GGML_OP_RESHAPE: + { + // necessary for llama + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_reshape(ctx, tensor->grad, src0->grad), + inplace); + } + } break; + case WSP_GGML_OP_VIEW: + { + // necessary for llama + if (src0->grad) { + size_t offset; + + WSP_GGML_ASSERT(sizeof(offset) <= wsp_ggml_nbytes(tensor->opt[0])); + memcpy(&offset, tensor->opt[0]->data, sizeof(offset)); + + size_t nb1 = tensor->nb[1]; + size_t nb2 = tensor->nb[2]; + size_t nb3 = tensor->nb[3]; + + if (src0->type != src0->grad->type) { + // gradient is typically F32, but src0 could be other type + size_t ng = wsp_ggml_element_size(src0->grad); + size_t n0 = wsp_ggml_element_size(src0); + WSP_GGML_ASSERT(offset % n0 == 0); + WSP_GGML_ASSERT(nb1 % n0 == 0); + WSP_GGML_ASSERT(nb2 % n0 == 0); + WSP_GGML_ASSERT(nb3 % n0 == 0); + offset = (offset / n0) * ng; + nb1 = (nb1 / n0) * ng; + nb2 = (nb2 / n0) * ng; + nb3 = (nb3 / n0) * ng; + } + + src0->grad = wsp_ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace); + } + } break; + case WSP_GGML_OP_PERMUTE: + { + // necessary for llama + if (src0->grad) { + int32_t * axes = (int32_t *) 
tensor->opt[0]->data; + int axis0 = axes[0] & 0x3; + int axis1 = axes[1] & 0x3; + int axis2 = axes[2] & 0x3; + int axis3 = axes[3] & 0x3; + int axes_backward[4] = {0,0,0,0}; + axes_backward[axis0] = 0; + axes_backward[axis1] = 1; + axes_backward[axis2] = 2; + axes_backward[axis3] = 3; + src0->grad = + wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_permute(ctx, + tensor->grad, + axes_backward[0], + axes_backward[1], + axes_backward[2], + axes_backward[3]), + inplace); + } + } break; + case WSP_GGML_OP_TRANSPOSE: + { + // necessary for llama + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_transpose(ctx, tensor->grad), + inplace); + } + } break; + case WSP_GGML_OP_GET_ROWS: + { + // necessary for llama (only for tokenizer) + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), + inplace); + } + if (src1->grad) { + // noop + } + } break; + case WSP_GGML_OP_GET_ROWS_BACK: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_DIAG: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_DIAG_MASK_INF: + { + // necessary for llama + if (src0->grad) { + assert(src1->type == WSP_GGML_TYPE_I32); + assert(wsp_ggml_nelements(src1) == 2); + const int n_past = ((int32_t *) src1->data)[0]; + src0->grad = + wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), + inplace); + } + if (src1->grad) { + // noop + } + } break; + case WSP_GGML_OP_DIAG_MASK_ZERO: + { + // necessary for llama + if (src0->grad) { + assert(src1->type == WSP_GGML_TYPE_I32); + assert(wsp_ggml_nelements(src1) == 2); + const int n_past = ((int32_t *) src1->data)[0]; + src0->grad = + wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), + inplace); + } + if (src1->grad) { + // noop + } + } break; + case WSP_GGML_OP_SOFT_MAX: + { + // necessary for llama + if (src0->grad) { + src0->grad = + wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_soft_max_back(ctx, tensor->grad, tensor), + inplace); + } + + } break; + case WSP_GGML_OP_SOFT_MAX_BACK: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_ROPE: + { + // necessary for llama + if (src0->grad) { + assert(src1->type == WSP_GGML_TYPE_I32); + assert(wsp_ggml_nelements(src1) == 4); + const int n_past = ((int32_t *) src1->data)[0]; + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + src0->grad = wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_rope_back(ctx, + tensor->grad, + n_past, + n_dims, + mode), + inplace); + } + if (src1->grad) { + // noop + } + } break; + case WSP_GGML_OP_ROPE_BACK: + { + if (src0->grad) { + assert(src1->type == WSP_GGML_TYPE_I32); + assert(wsp_ggml_nelements(src1) == 4); + const int n_past = ((int32_t *) src1->data)[0]; + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + const int n_ctx = ((int32_t *) src1->data)[3]; + src0->grad = wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_rope(ctx, + tensor->grad, + n_past, + n_dims, + mode, + n_ctx), + inplace); + } + if (src1->grad) { + // noop + } + } break; + case WSP_GGML_OP_ALIBI: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_CLAMP: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_CONV_1D: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case 
WSP_GGML_OP_CONV_2D: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_FLASH_ATTN: + { + struct wsp_ggml_tensor * flash_grad = NULL; + if (src0->grad || src1->grad || tensor->opt[0]->grad) { + int32_t t = wsp_ggml_get_i32_1d(tensor->opt[1], 0); + WSP_GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + flash_grad = + wsp_ggml_flash_attn_back(ctx, + src0, + src1, + tensor->opt[0], + tensor->grad, + masked); + } + + if (src0->grad) { + struct wsp_ggml_tensor * grad_q = NULL; + const size_t nb0 = flash_grad->nb[0]; + const size_t offset = 0; + switch(src0->n_dims) { + case 2: + { + grad_q = wsp_ggml_view_2d(ctx, + flash_grad, + src0->ne[0], + src0->ne[1], + nb0*src0->ne[0], + offset); + } break; + case 3: + { + grad_q = wsp_ggml_view_3d(ctx, + flash_grad, + src0->ne[0], + src0->ne[1], + src0->ne[2], + nb0*src0->ne[0], + nb0*src0->ne[0]*src0->ne[1], + offset); + } break; + case 4: + { + grad_q = wsp_ggml_view_4d(ctx, + flash_grad, + src0->ne[0], + src0->ne[1], + src0->ne[2], + src0->ne[3], + nb0*src0->ne[0], + nb0*src0->ne[0]*src0->ne[1], + nb0*src0->ne[0]*src0->ne[1]*src0->ne[2], + offset); + } break; + } + + src0->grad = wsp_ggml_add_impl(ctx, + src0->grad, + grad_q, + inplace); + } + + if (src1->grad) { + struct wsp_ggml_tensor * grad_k = NULL; + const size_t nb0 = flash_grad->nb[0]; + const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]; + switch(src1->n_dims) { + case 2: + { + grad_k = wsp_ggml_view_2d(ctx, + flash_grad, + src1->ne[0], + src1->ne[1], + nb0*src1->ne[0], + offset); + } break; + case 3: + { + grad_k = wsp_ggml_view_3d(ctx, + flash_grad, + src1->ne[0], + src1->ne[1], + src1->ne[2], + nb0*src1->ne[0], + nb0*src1->ne[0]*src1->ne[1], + offset); + } break; + case 4: + { + grad_k = wsp_ggml_view_4d(ctx, + flash_grad, + src1->ne[0], + src1->ne[1], + src1->ne[2], + src1->ne[3], + nb0*src1->ne[0], + nb0*src1->ne[0]*src1->ne[1], + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2], + offset); + } break; + } + + src1->grad = wsp_ggml_add_impl(ctx, + src1->grad, + grad_k, + inplace); + } + + struct wsp_ggml_tensor * opt0 = tensor->opt[0]; + + if (opt0->grad) { + struct wsp_ggml_tensor * grad_v = NULL; + const size_t nb0 = flash_grad->nb[0]; + const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3] + + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3]; + switch(opt0->n_dims) { + case 2: + { + grad_v = wsp_ggml_view_2d(ctx, + flash_grad, + opt0->ne[0], + opt0->ne[1], + nb0*opt0->ne[0], + offset); + } break; + case 3: + { + grad_v = wsp_ggml_view_3d(ctx, + flash_grad, + opt0->ne[0], + opt0->ne[1], + opt0->ne[2], + nb0*opt0->ne[0], + nb0*opt0->ne[0]*opt0->ne[1], + offset); + } break; + case 4: + { + grad_v = wsp_ggml_view_4d(ctx, + flash_grad, + opt0->ne[0], + opt0->ne[1], + opt0->ne[2], + opt0->ne[3], + nb0*opt0->ne[0], + nb0*opt0->ne[0]*opt0->ne[1], + nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2], + offset); + } break; + } + + opt0->grad = wsp_ggml_add_impl(ctx, + opt0->grad, + grad_v, + inplace); + } + } break; + case WSP_GGML_OP_FLASH_FF: + { + WSP_GGML_ASSERT(false); // not supported + } break; + case WSP_GGML_OP_FLASH_ATTN_BACK: + { + WSP_GGML_ASSERT(false); // not supported + } break; + case WSP_GGML_OP_WIN_PART: + case WSP_GGML_OP_WIN_UNPART: + case WSP_GGML_OP_MAP_UNARY: + case WSP_GGML_OP_MAP_BINARY: + case WSP_GGML_OP_MAP_CUSTOM1: + case WSP_GGML_OP_MAP_CUSTOM2: + case WSP_GGML_OP_MAP_CUSTOM3: + { + WSP_GGML_ASSERT(false); // not supported + } break; + case WSP_GGML_OP_CROSS_ENTROPY_LOSS: + { + if (src0->grad) { + 
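Aside: the WSP_GGML_OP_FLASH_ATTN branch above recovers grad_q, grad_k and grad_v as views into the single flash_grad tensor at increasing byte offsets. A toy standalone illustration of that flat-buffer slicing, using plain pointers instead of wsp_ggml_view_*d (the element counts are invented):

    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        const size_t nq = 16, nk = 24, nv = 24;   /* made-up element counts */
        float * flash_grad = calloc(nq + nk + nv, sizeof(float));
        if (!flash_grad) return 1;

        float * grad_q = flash_grad;              /* offset 0                */
        float * grad_k = flash_grad + nq;         /* offset nq elements      */
        float * grad_v = flash_grad + nq + nk;    /* offset nq + nk elements */

        grad_q[0] = 1.0f; grad_k[0] = 2.0f; grad_v[0] = 3.0f;
        printf("%.0f %.0f %.0f\n", flash_grad[0], flash_grad[nq], flash_grad[nq + nk]);

        free(flash_grad);
        return 0;
    }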
src0->grad = wsp_ggml_add_impl(ctx, + src0->grad, + wsp_ggml_cross_entropy_loss_back(ctx, + src0, + src1, + tensor->grad), + inplace); + } + } break; + case WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + WSP_GGML_ASSERT(false); // not supported + } break; + case WSP_GGML_OP_NONE: + { + // nop + } break; + case WSP_GGML_OP_COUNT: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +static void wsp_ggml_visit_parents(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * node) { + if (node->grad == NULL) { + // this usually happens when we generate intermediate nodes from constants in the backward pass + // it can also happen during forward pass, if the user performs computations with constants + if (node->op != WSP_GGML_OP_NONE) { + //WSP_GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); + } + } + + // check if already visited + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i] == node) { + return; + } + } + + for (int i = 0; i < cgraph->n_leafs; i++) { + if (cgraph->leafs[i] == node) { + return; + } + } + + if (node->src0) { + wsp_ggml_visit_parents(cgraph, node->src0); + } + + if (node->src1) { + wsp_ggml_visit_parents(cgraph, node->src1); + } + + for (int i = 0; i < WSP_GGML_MAX_OPT; ++i) { + if (node->opt[i]) { + wsp_ggml_visit_parents(cgraph, node->opt[i]); + } + } + + if (node->op == WSP_GGML_OP_NONE && node->grad == NULL) { + // reached a leaf node, not part of the gradient graph (e.g. a constant) + WSP_GGML_ASSERT(cgraph->n_leafs < WSP_GGML_MAX_NODES); + + if (strlen(node->name) == 0) { + wsp_ggml_format_name(node, "leaf_%d", cgraph->n_leafs); + } + + cgraph->leafs[cgraph->n_leafs] = node; + cgraph->n_leafs++; + } else { + WSP_GGML_ASSERT(cgraph->n_nodes < WSP_GGML_MAX_NODES); + + if (strlen(node->name) == 0) { + wsp_ggml_format_name(node, "node_%d", cgraph->n_nodes); + } + + cgraph->nodes[cgraph->n_nodes] = node; + cgraph->grads[cgraph->n_nodes] = node->grad; + cgraph->n_nodes++; + } +} + +static void wsp_ggml_build_forward_impl(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor, bool expand) { + if (!expand) { + cgraph->n_nodes = 0; + cgraph->n_leafs = 0; + } + + const int n0 = cgraph->n_nodes; + UNUSED(n0); + + wsp_ggml_visit_parents(cgraph, tensor); + + const int n_new = cgraph->n_nodes - n0; + WSP_GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); + + if (n_new > 0) { + // the last added node should always be starting point + WSP_GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor); + } +} + +void wsp_ggml_build_forward_expand(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor) { + wsp_ggml_build_forward_impl(cgraph, tensor, true); +} + +struct wsp_ggml_cgraph wsp_ggml_build_forward(struct wsp_ggml_tensor * tensor) { + struct wsp_ggml_cgraph result = { + /*.n_nodes =*/ 0, + /*.n_leafs =*/ 0, + /*.n_threads =*/ WSP_GGML_DEFAULT_N_THREADS, + /*.work_size =*/ 0, + /*.work =*/ NULL, + /*.nodes =*/ { NULL }, + /*.grads =*/ { NULL }, + /*.leafs =*/ { NULL }, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + }; + + wsp_ggml_build_forward_impl(&result, tensor, false); + + return result; +} + +struct wsp_ggml_cgraph wsp_ggml_build_backward(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, bool keep) { + struct wsp_ggml_cgraph result = *gf; + + WSP_GGML_ASSERT(gf->n_nodes > 0); + + // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph + if (keep) { + for (int i = 0; i < gf->n_nodes; i++) { + 
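Aside: wsp_ggml_visit_parents above is a post-order depth-first traversal: a tensor is appended to nodes[] (or leafs[]) only after everything it depends on, which is what makes the flat nodes[] array a valid evaluation order for the forward pass. A tiny standalone sketch of the same idea, with toy structs rather than the ggml types:

    #include <stdio.h>

    #define MAX_NODES 16

    typedef struct toy_node {
        const char * name;
        struct toy_node * src0;
        struct toy_node * src1;
    } toy_node;

    static toy_node * order[MAX_NODES];
    static int n_order = 0;

    static void visit(toy_node * node) {
        if (!node) return;
        for (int i = 0; i < n_order; ++i) if (order[i] == node) return; /* already visited */
        visit(node->src0);
        visit(node->src1);
        order[n_order++] = node;   /* emitted only after its inputs */
    }

    int main(void) {
        toy_node a   = {"a", NULL, NULL}, b = {"b", NULL, NULL};
        toy_node mul = {"a*b", &a, &b};
        toy_node out = {"(a*b)+b", &mul, &b};

        visit(&out);
        for (int i = 0; i < n_order; ++i) printf("%s\n", order[i]->name);
        return 0;
    }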
struct wsp_ggml_tensor * node = gf->nodes[i]; + + if (node->grad) { + node->grad = wsp_ggml_dup_tensor(ctx, node); + gf->grads[i] = node->grad; + } + } + } + + for (int i = gf->n_nodes - 1; i >= 0; i--) { + struct wsp_ggml_tensor * node = gf->nodes[i]; + + // because we detached the grad nodes from the original graph, we can afford inplace operations + if (node->grad) { + wsp_ggml_compute_backward(ctx, node, keep); + } + } + + for (int i = gf->n_nodes - 1; i >= 0; i--) { + struct wsp_ggml_tensor * node = gf->nodes[i]; + + if (node->is_param) { + WSP_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); + wsp_ggml_build_forward_impl(&result, node->grad, true); + } + } + + return result; +} + +// +// thread data +// +// synchronization is done via busy loops +// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops +// + +#ifdef __APPLE__ + +//#include +// +//typedef os_unfair_lock wsp_ggml_lock_t; +// +//#define wsp_ggml_lock_init(x) UNUSED(x) +//#define wsp_ggml_lock_destroy(x) UNUSED(x) +//#define wsp_ggml_lock_lock os_unfair_lock_lock +//#define wsp_ggml_lock_unlock os_unfair_lock_unlock +// +//#define WSP_GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT + +typedef int wsp_ggml_lock_t; + +#define wsp_ggml_lock_init(x) UNUSED(x) +#define wsp_ggml_lock_destroy(x) UNUSED(x) +#define wsp_ggml_lock_lock(x) UNUSED(x) +#define wsp_ggml_lock_unlock(x) UNUSED(x) + +#define WSP_GGML_LOCK_INITIALIZER 0 + +typedef pthread_t wsp_ggml_thread_t; + +#define wsp_ggml_thread_create pthread_create +#define wsp_ggml_thread_join pthread_join + +#else + +//typedef pthread_spinlock_t wsp_ggml_lock_t; + +//#define wsp_ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE) +//#define wsp_ggml_lock_destroy pthread_spin_destroy +//#define wsp_ggml_lock_lock pthread_spin_lock +//#define wsp_ggml_lock_unlock pthread_spin_unlock + +typedef int wsp_ggml_lock_t; + +#define wsp_ggml_lock_init(x) UNUSED(x) +#define wsp_ggml_lock_destroy(x) UNUSED(x) +#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) +#define wsp_ggml_lock_lock(x) _mm_pause() +#else +#define wsp_ggml_lock_lock(x) UNUSED(x) +#endif +#define wsp_ggml_lock_unlock(x) UNUSED(x) + +#define WSP_GGML_LOCK_INITIALIZER 0 + +typedef pthread_t wsp_ggml_thread_t; + +#define wsp_ggml_thread_create pthread_create +#define wsp_ggml_thread_join pthread_join + +#endif + +// Android's libc implementation "bionic" does not support setting affinity +#if defined(__linux__) && !defined(__BIONIC__) +void set_numa_thread_affinity(int thread_n, int n_threads) { + if (!wsp_ggml_is_numa()) { + return; + } + + // run thread on node_num thread_n / (threads per node) + const int node_num = thread_n / ((n_threads + g_state.numa.n_nodes - 1) / g_state.numa.n_nodes); + struct wsp_ggml_numa_node * node = &g_state.numa.nodes[node_num]; + size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); + + cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); + CPU_ZERO_S(setsize, cpus); + for (size_t i = 0; i < node->n_cpus; ++i) { + CPU_SET_S(node->cpus[i], setsize, cpus); + } + + int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", + strerror(rv)); + } + + CPU_FREE(cpus); +} + +void clear_numa_thread_affinity(void) { + if (!wsp_ggml_is_numa()) { + return; + } + + size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); + + cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); + CPU_ZERO_S(setsize, cpus); + 
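Aside: set_numa_thread_affinity above assigns worker thread_n to NUMA node thread_n / ceil(n_threads / n_nodes). A small standalone sketch of just that mapping (the thread and node counts are made up):

    #include <stdio.h>

    int main(void) {
        const int n_threads = 8;
        const int n_nodes   = 3;
        const int per_node  = (n_threads + n_nodes - 1) / n_nodes;   /* ceil(8/3) = 3 */

        for (int t = 0; t < n_threads; ++t) {
            printf("thread %d -> numa node %d\n", t, t / per_node);
        }
        return 0;
    }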
for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) { + CPU_SET_S(i, setsize, cpus); + } + + int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", + strerror(rv)); + } + + CPU_FREE(cpus); +} +#else +// TODO: Windows etc. +// (the linux implementation may also work on BSD, someone should test) +void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads); } +void clear_numa_thread_affinity(void) {} +#endif + +struct wsp_ggml_compute_state_shared { + struct wsp_ggml_cgraph * cgraph; + + int64_t perf_node_start_cycles; + int64_t perf_node_start_time_us; + + int n_threads; + + // synchronization primitives + atomic_int n_active; // num active threads + atomic_int node_n; // active graph node +}; + +struct wsp_ggml_compute_state { + wsp_ggml_thread_t thrd; + int ith; + struct wsp_ggml_compute_state_shared * shared; +}; + +static void wsp_ggml_graph_compute_perf_stats_node(struct wsp_ggml_tensor * node, const struct wsp_ggml_compute_state_shared * st) { + int64_t cycles_cur = wsp_ggml_perf_cycles() - st->perf_node_start_cycles; + int64_t time_us_cur = wsp_ggml_perf_time_us() - st->perf_node_start_time_us; + + node->perf_runs++; + node->perf_cycles += cycles_cur; + node->perf_time_us += time_us_cur; +} + +static thread_ret_t wsp_ggml_graph_compute_thread(void * data) { + struct wsp_ggml_compute_state * state = (struct wsp_ggml_compute_state *) data; + struct wsp_ggml_cgraph * cgraph = state->shared->cgraph; + + const int n_threads = state->shared->n_threads; + set_numa_thread_affinity(state->ith, n_threads); + + int node_n = -1; + + while (true) { + if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { + // all other threads are finished and spinning + // do finalize and init here so we don't have synchronize again + struct wsp_ggml_compute_params params = { + /*.type =*/ WSP_GGML_TASK_FINALIZE, + /*.ith =*/ 0, + /*.nth =*/ 0, + /*.wsize =*/ cgraph->work ? wsp_ggml_nbytes(cgraph->work) : 0, + /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL, + }; + + if (node_n != -1) { + /* FINALIZE */ + struct wsp_ggml_tensor * node = state->shared->cgraph->nodes[node_n]; + if (WSP_GGML_OP_HAS_FINALIZE[node->op]) { + params.nth = node->n_tasks; + wsp_ggml_compute_forward(¶ms, node); + wsp_ggml_graph_compute_perf_stats_node(node, state->shared); + } + } + + // distribute new work or execute it direct if 1T + while (++node_n < cgraph->n_nodes) { + WSP_GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes); + + struct wsp_ggml_tensor * node = cgraph->nodes[node_n]; + + state->shared->perf_node_start_cycles = wsp_ggml_perf_cycles(); + state->shared->perf_node_start_time_us = wsp_ggml_perf_time_us(); + + params.nth = node->n_tasks; + + /* INIT */ + if (WSP_GGML_OP_HAS_INIT[node->op]) { + params.type = WSP_GGML_TASK_INIT; + wsp_ggml_compute_forward(¶ms, node); + } + + if (node->n_tasks == 1) { + // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1, + // they do something more efficient than spinning (?) 
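Aside: as the "thread data" comment above says, worker coordination here uses two atomics and busy waiting: the last thread to decrement n_active performs the serial FINALIZE/INIT step and publishes the next node index in node_n, while the other threads sched_yield until node_n moves. A cut-down, runnable model of that handshake (C11 atomics plus pthreads; the node count and the printf stand in for real work):

    #include <pthread.h>
    #include <sched.h>
    #include <stdatomic.h>
    #include <stdio.h>

    #define N_THREADS 4
    #define N_NODES   8

    static atomic_int n_active = N_THREADS;
    static atomic_int node_n   = -1;

    static void * worker(void * arg) {
        (void) arg;
        int last = -1;
        while (1) {
            if (atomic_fetch_sub(&n_active, 1) == 1) {
                /* last thread in: do the serial step, then release everyone */
                int next = last + 1;
                if (next < N_NODES) printf("node %d scheduled\n", next);
                atomic_store(&n_active, N_THREADS);
                atomic_store(&node_n, next);
                last = next;
            } else {
                do { sched_yield(); } while (atomic_load(&node_n) == last);
                last = atomic_load(&node_n);
            }
            if (last >= N_NODES) break;
            /* ... per-thread COMPUTE work for node `last` would go here ... */
        }
        return NULL;
    }

    int main(void) {
        pthread_t th[N_THREADS];
        for (int i = 0; i < N_THREADS; ++i) pthread_create(&th[i], NULL, worker, NULL);
        for (int i = 0; i < N_THREADS; ++i) pthread_join(th[i], NULL);
        return 0;
    }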
+ params.type = WSP_GGML_TASK_COMPUTE; + wsp_ggml_compute_forward(¶ms, node); + + if (WSP_GGML_OP_HAS_FINALIZE[node->op]) { + params.type = WSP_GGML_TASK_FINALIZE; + wsp_ggml_compute_forward(¶ms, node); + wsp_ggml_graph_compute_perf_stats_node(node, state->shared); + } + } else { + break; + } + } + + atomic_store(&state->shared->n_active, n_threads); + atomic_store(&state->shared->node_n, node_n); + } else { + // wait for other threads to finish + const int last = node_n; + do { + sched_yield(); + node_n = atomic_load(&state->shared->node_n); + } while (node_n == last); + } + + // check if we should stop + if (node_n >= cgraph->n_nodes) break; + + /* COMPUTE */ + struct wsp_ggml_tensor * node = cgraph->nodes[node_n]; + + struct wsp_ggml_compute_params params = { + /*.type =*/ WSP_GGML_TASK_COMPUTE, + /*.ith =*/ state->ith, + /*.nth =*/ node->n_tasks, + /*.wsize =*/ cgraph->work ? wsp_ggml_nbytes(cgraph->work) : 0, + /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL, + }; + + if (state->ith < node->n_tasks) { + wsp_ggml_compute_forward(¶ms, node); + } + } + + return 0; +} + +void wsp_ggml_graph_compute(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph) { + const int n_threads = cgraph->n_threads; + + struct wsp_ggml_compute_state_shared state_shared = { + /*.cgraph =*/ cgraph, + /*.perf_node_start_cycles =*/ 0, + /*.perf_node_start_time_us =*/ 0, + /*.n_threads =*/ n_threads, + /*.n_active =*/ n_threads, + /*.node_n =*/ -1, + }; + struct wsp_ggml_compute_state * workers = alloca(sizeof(struct wsp_ggml_compute_state)*n_threads); + + // initialize tasks + work buffer + { + size_t work_size = 0; + + // thread scheduling for the different operations + for (int i = 0; i < cgraph->n_nodes; i++) { + struct wsp_ggml_tensor * node = cgraph->nodes[i]; + + switch (node->op) { + case WSP_GGML_OP_CPY: + case WSP_GGML_OP_DUP: + { + node->n_tasks = n_threads; + + size_t cur = 0; + if (wsp_ggml_is_quantized(node->type)) { + cur = WSP_GGML_TYPE_SIZE[WSP_GGML_TYPE_F32] * node->ne[0] * n_threads; + } + + work_size = MAX(work_size, cur); + } break; + case WSP_GGML_OP_ADD: + case WSP_GGML_OP_ADD1: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + if (wsp_ggml_is_quantized(node->src0->type)) { + cur = WSP_GGML_TYPE_SIZE[WSP_GGML_TYPE_F32] * node->src0->ne[0] * n_threads; + } + + work_size = MAX(work_size, cur); + } break; + case WSP_GGML_OP_ACC: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + if (wsp_ggml_is_quantized(node->src0->type)) { + cur = WSP_GGML_TYPE_SIZE[WSP_GGML_TYPE_F32] * node->src1->ne[0] * n_threads; + } + + work_size = MAX(work_size, cur); + } break; + case WSP_GGML_OP_SUB: + case WSP_GGML_OP_DIV: + case WSP_GGML_OP_SQR: + case WSP_GGML_OP_SQRT: + case WSP_GGML_OP_LOG: + case WSP_GGML_OP_SUM: + case WSP_GGML_OP_SUM_ROWS: + case WSP_GGML_OP_MEAN: + case WSP_GGML_OP_ARGMAX: + case WSP_GGML_OP_REPEAT: + case WSP_GGML_OP_REPEAT_BACK: + case WSP_GGML_OP_ABS: + case WSP_GGML_OP_SGN: + case WSP_GGML_OP_NEG: + case WSP_GGML_OP_STEP: + case WSP_GGML_OP_TANH: + case WSP_GGML_OP_ELU: + case WSP_GGML_OP_RELU: + { + node->n_tasks = 1; + } break; + case WSP_GGML_OP_MUL: + case WSP_GGML_OP_GELU: + case WSP_GGML_OP_GELU_QUICK: + case WSP_GGML_OP_SILU: + case WSP_GGML_OP_SILU_BACK: + case WSP_GGML_OP_NORM: + case WSP_GGML_OP_RMS_NORM: + case WSP_GGML_OP_RMS_NORM_BACK: + { + node->n_tasks = n_threads; + } break; + case WSP_GGML_OP_MUL_MAT: + case WSP_GGML_OP_OUT_PROD: + { + node->n_tasks = n_threads; + + // TODO: use different scheduling for different matrix sizes + //const int 
nr0 = wsp_ggml_nrows(node->src0); + //const int nr1 = wsp_ggml_nrows(node->src1); + + //node->n_tasks = MIN(n_threads, MAX(1, nr0/128)); + //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks); + + size_t cur = 0; + +#if defined(WSP_GGML_USE_CUBLAS) + if (wsp_ggml_cuda_can_mul_mat(node->src0, node->src1, node)) { + node->n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } + else +#elif defined(WSP_GGML_USE_CLBLAST) + if (wsp_ggml_cl_can_mul_mat(node->src0, node->src1, node)) { + node->n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + cur = wsp_ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node); + } + else +#endif + if (node->src0->type == WSP_GGML_TYPE_F16 && node->src1->type == WSP_GGML_TYPE_F32) { +#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) + if (wsp_ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + // here we need memory just for single 2D matrix from src0 + cur = WSP_GGML_TYPE_SIZE[WSP_GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); + } else { + cur = WSP_GGML_TYPE_SIZE[WSP_GGML_TYPE_F16]*wsp_ggml_nelements(node->src1); + } +#else + cur = WSP_GGML_TYPE_SIZE[WSP_GGML_TYPE_F16]*wsp_ggml_nelements(node->src1); +#endif + } else if (node->src0->type == WSP_GGML_TYPE_F32 && node->src1->type == WSP_GGML_TYPE_F32) { + cur = 0; +#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) + if (wsp_ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; + } +#endif + } else if (wsp_ggml_is_quantized(node->src0->type) && node->src1->type == WSP_GGML_TYPE_F32) { +#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) + if (wsp_ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; + cur = WSP_GGML_TYPE_SIZE[WSP_GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); + } else +#endif + { + const enum wsp_ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type; + cur = WSP_GGML_TYPE_SIZE[type_q]*wsp_ggml_nelements(node->src1)/WSP_GGML_BLCK_SIZE[type_q]; + } + } else { + WSP_GGML_ASSERT(false); + } + + work_size = MAX(work_size, cur); + } break; + case WSP_GGML_OP_SCALE: + { + node->n_tasks = 1; + } break; + case WSP_GGML_OP_SET: + case WSP_GGML_OP_CONT: + case WSP_GGML_OP_RESHAPE: + case WSP_GGML_OP_VIEW: + case WSP_GGML_OP_PERMUTE: + case WSP_GGML_OP_TRANSPOSE: + case WSP_GGML_OP_GET_ROWS: + case WSP_GGML_OP_GET_ROWS_BACK: + case WSP_GGML_OP_DIAG: + case WSP_GGML_OP_DIAG_MASK_ZERO: + { + node->n_tasks = 1; + } break; + case WSP_GGML_OP_DIAG_MASK_INF: + case WSP_GGML_OP_SOFT_MAX: + case WSP_GGML_OP_SOFT_MAX_BACK: + case WSP_GGML_OP_ROPE: + case WSP_GGML_OP_ROPE_BACK: + { + node->n_tasks = n_threads; + } break; + case WSP_GGML_OP_ALIBI: + { + node->n_tasks = 1; //TODO + } break; + case WSP_GGML_OP_CLAMP: + { + node->n_tasks = 1; //TODO + } break; + case WSP_GGML_OP_CONV_1D: + { + node->n_tasks = n_threads; + + WSP_GGML_ASSERT(node->src0->ne[3] == 1); + WSP_GGML_ASSERT(node->src1->ne[2] == 1); + WSP_GGML_ASSERT(node->src1->ne[3] == 1); + + size_t cur = 0; + const int nk = node->src0->ne[0]; + + if (node->src0->type == WSP_GGML_TYPE_F16 && + node->src1->type == WSP_GGML_TYPE_F32) { + cur = sizeof(wsp_ggml_fp16_t)*( + nk*wsp_ggml_up32(node->src0->ne[1])*node->src0->ne[2] + + ( 2*(nk/2) + 
node->src1->ne[0])*node->src1->ne[1] + ); + } else if (node->src0->type == WSP_GGML_TYPE_F32 && + node->src1->type == WSP_GGML_TYPE_F32) { + cur = sizeof(float)*( + nk*wsp_ggml_up32(node->src0->ne[1])*node->src0->ne[2] + + ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] + ); + } else { + WSP_GGML_ASSERT(false); + } + + work_size = MAX(work_size, cur); + } break; + case WSP_GGML_OP_CONV_2D: + { + node->n_tasks = n_threads; + + WSP_GGML_ASSERT(node->src1->ne[3] == 1); + + const int64_t ne00 = node->src0->ne[0]; // W + const int64_t ne01 = node->src0->ne[1]; // H + const int64_t ne02 = node->src0->ne[2]; // C + const int64_t ne03 = node->src0->ne[3]; // N + + const int64_t ne10 = node->src1->ne[0]; // W + const int64_t ne11 = node->src1->ne[1]; // H + const int64_t ne12 = node->src1->ne[2]; // C + + const int64_t nk = ne00*ne01; + + UNUSED(ne02); + UNUSED(ne03); + UNUSED(nk); + + size_t cur = 0; + + if (node->src0->type == WSP_GGML_TYPE_F16 && + node->src1->type == WSP_GGML_TYPE_F32) { + cur = sizeof(wsp_ggml_fp16_t)*(ne10*ne11*ne12); + } else if (node->src0->type == WSP_GGML_TYPE_F32 && + node->src1->type == WSP_GGML_TYPE_F32) { + cur = sizeof(float)* (ne10*ne11*ne12); + } else { + WSP_GGML_ASSERT(false); + } + + work_size = MAX(work_size, cur); + } break; + case WSP_GGML_OP_FLASH_ATTN: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + const int64_t ne11 = wsp_ggml_up(node->src1->ne[1], WSP_GGML_SOFT_MAX_UNROLL); + + if (node->src1->type == WSP_GGML_TYPE_F32) { + cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == WSP_GGML_TYPE_F16) { + cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case WSP_GGML_OP_FLASH_FF: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + if (node->src1->type == WSP_GGML_TYPE_F32) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == WSP_GGML_TYPE_F16) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case WSP_GGML_OP_FLASH_ATTN_BACK: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + const int64_t D = node->src0->ne[0]; + const int64_t ne11 = wsp_ggml_up(node->src1->ne[1], WSP_GGML_SOFT_MAX_UNROLL); + const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in wsp_ggml_compute_forward_flash_attn_back + if (node->src1->type == WSP_GGML_TYPE_F32) { + cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == WSP_GGML_TYPE_F16) { + cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case WSP_GGML_OP_WIN_PART: + case WSP_GGML_OP_WIN_UNPART: + case WSP_GGML_OP_MAP_UNARY: + case WSP_GGML_OP_MAP_BINARY: + case WSP_GGML_OP_MAP_CUSTOM1: + case WSP_GGML_OP_MAP_CUSTOM2: + case WSP_GGML_OP_MAP_CUSTOM3: + { + node->n_tasks = 1; + } break; + case 
WSP_GGML_OP_CROSS_ENTROPY_LOSS: + { + node->n_tasks = n_threads; + + size_t cur = wsp_ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks); + + work_size = MAX(work_size, cur); + } break; + case WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + node->n_tasks = n_threads; + + size_t cur = wsp_ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks; + + work_size = MAX(work_size, cur); + } break; + case WSP_GGML_OP_NONE: + { + node->n_tasks = 1; + } break; + case WSP_GGML_OP_COUNT: + { + WSP_GGML_ASSERT(false); + } break; + } + } + + if (cgraph->work != NULL && work_size > cgraph->work_size) { + WSP_GGML_ASSERT(false); // TODO: better handling + } + + if (work_size > 0 && cgraph->work == NULL) { + cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1); + + WSP_GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size); + cgraph->work = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I8, cgraph->work_size); + } + } + + // create thread pool + if (n_threads > 1) { + for (int j = 1; j < n_threads; ++j) { + workers[j] = (struct wsp_ggml_compute_state) { + .thrd = 0, + .ith = j, + .shared = &state_shared, + }; + + const int rc = wsp_ggml_thread_create(&workers[j].thrd, NULL, wsp_ggml_graph_compute_thread, &workers[j]); + WSP_GGML_ASSERT(rc == 0); + } + } + workers[0].ith = 0; + workers[0].shared = &state_shared; + + const int64_t perf_start_cycles = wsp_ggml_perf_cycles(); + const int64_t perf_start_time_us = wsp_ggml_perf_time_us(); + + // this is a work thread too + wsp_ggml_graph_compute_thread(&workers[0]); + + // don't leave affinity set on the main thread + clear_numa_thread_affinity(); + + // join thread pool + if (n_threads > 1) { + for (int j = 1; j < n_threads; j++) { + const int rc = wsp_ggml_thread_join(workers[j].thrd, NULL); + WSP_GGML_ASSERT(rc == 0); + } + } + + // performance stats (graph) + { + int64_t perf_cycles_cur = wsp_ggml_perf_cycles() - perf_start_cycles; + int64_t perf_time_us_cur = wsp_ggml_perf_time_us() - perf_start_time_us; + + cgraph->perf_runs++; + cgraph->perf_cycles += perf_cycles_cur; + cgraph->perf_time_us += perf_time_us_cur; + + WSP_GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n", + __func__, cgraph->perf_runs, + (double) perf_cycles_cur / (double) wsp_ggml_cycles_per_ms(), + (double) cgraph->perf_cycles / (double) wsp_ggml_cycles_per_ms() / (double) cgraph->perf_runs, + (double) perf_time_us_cur / 1000.0, + (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); + } +} + +void wsp_ggml_graph_reset(struct wsp_ggml_cgraph * cgraph) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct wsp_ggml_tensor * grad = cgraph->grads[i]; + + if (grad) { + wsp_ggml_set_zero(grad); + } + } +} + +struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(struct wsp_ggml_cgraph * cgraph, const char * name) { + for (int i = 0; i < cgraph->n_leafs; i++) { + struct wsp_ggml_tensor * leaf = cgraph->leafs[i]; + + if (strcmp(leaf->name, name) == 0) { + return leaf; + } + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + struct wsp_ggml_tensor * node = cgraph->nodes[i]; + + if (strcmp(node->name, name) == 0) { + return node; + } + } + + return NULL; +} + +static void wsp_ggml_graph_export_leaf(const struct wsp_ggml_tensor * tensor, FILE * fout) { + const int64_t * ne = tensor->ne; + const size_t * nb = tensor->nb; + + fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", + wsp_ggml_type_name(tensor->type), + 
wsp_ggml_op_name (tensor->op), + tensor->n_dims, + ne[0], ne[1], ne[2], ne[3], + nb[0], nb[1], nb[2], nb[3], + tensor->data, + tensor->name); +} + +static void wsp_ggml_graph_export_node(const struct wsp_ggml_tensor * tensor, const char * arg, FILE * fout) { + const int64_t * ne = tensor->ne; + const size_t * nb = tensor->nb; + + fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n", + arg, + wsp_ggml_type_name(tensor->type), + wsp_ggml_op_name (tensor->op), + tensor->n_dims, + ne[0], ne[1], ne[2], ne[3], + nb[0], nb[1], nb[2], nb[3], + tensor->n_tasks, + tensor->data, + tensor->name); +} + +void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname) { + //assert(cgraph->work == NULL); + //assert(cgraph->work_size == 0); + + uint64_t size_eval = 0; + + // compute size of intermediate results + // TODO: does not take into account scratch buffers !!!! + for (int i = 0; i < cgraph->n_nodes; ++i) { + size_eval += wsp_ggml_nbytes(cgraph->nodes[i]); + } + + // print + { + FILE * fout = stdout; + + fprintf(fout, "\n"); + fprintf(fout, "%-16s %8x\n", "magic", WSP_GGML_FILE_MAGIC); + fprintf(fout, "%-16s %8d\n", "version", WSP_GGML_FILE_VERSION); + fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs); + fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes); + fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval); + + // header + fprintf(fout, "\n"); + fprintf(fout, "%-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %16s %16s\n", + "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "DATA", "NAME"); + + for (int i = 0; i < cgraph->n_leafs; ++i) { + wsp_ggml_graph_export_leaf(cgraph->leafs[i], fout); + + WSP_GGML_ASSERT(cgraph->leafs[i]->op == WSP_GGML_OP_NONE); + WSP_GGML_ASSERT(cgraph->leafs[i]->src0 == NULL); + WSP_GGML_ASSERT(cgraph->leafs[i]->src1 == NULL); + } + + // header + fprintf(fout, "\n"); + fprintf(fout, "%-6s %-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %8s %16s %16s\n", + "ARG", "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "NTASKS", "DATA", "NAME"); + + for (int i = 0; i < cgraph->n_nodes; ++i) { + wsp_ggml_graph_export_node(cgraph->nodes[i], "DST", fout); + + if (cgraph->nodes[i]->src0) { + wsp_ggml_graph_export_node(cgraph->nodes[i]->src0, "SRC0", fout); + } + + if (cgraph->nodes[i]->src1) { + wsp_ggml_graph_export_node(cgraph->nodes[i]->src1, "SRC1", fout); + } + + for (int j = 0; j < WSP_GGML_MAX_OPT; ++j) { + if (cgraph->nodes[i]->opt[j]) { + wsp_ggml_graph_export_node(cgraph->nodes[i]->opt[j], "OPT", fout); + } + } + + fprintf(fout, "\n"); + } + + fprintf(fout, "\n"); + } + + // write binary data + { + FILE * fout = fopen(fname, "wb"); + + if (!fout) { + fprintf(stderr, "%s: failed to open %s\n", __func__, fname); + return; + } + + // header + { + const uint32_t magic = WSP_GGML_FILE_MAGIC; + const uint32_t version = WSP_GGML_FILE_VERSION; + const uint32_t n_leafs = cgraph->n_leafs; + const uint32_t nodes = cgraph->n_nodes; + + fwrite(&magic, sizeof(uint32_t), 1, fout); + fwrite(&version, sizeof(uint32_t), 1, fout); + fwrite(&n_leafs, sizeof(uint32_t), 1, fout); + fwrite(&nodes, sizeof(uint32_t), 1, fout); + fwrite(&size_eval, sizeof(uint64_t), 1, fout); + } + + // leafs + { + for (int i = 0; i < cgraph->n_leafs; ++i) { + const struct wsp_ggml_tensor * tensor = cgraph->leafs[i]; + + const uint32_t type = tensor->type; + const uint32_t op = tensor->op; + const uint32_t n_dims = tensor->n_dims; + + fwrite(&type, 
sizeof(uint32_t), 1, fout); + fwrite(&op, sizeof(uint32_t), 1, fout); + fwrite(&n_dims, sizeof(uint32_t), 1, fout); + + for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) { + const uint64_t ne = tensor->ne[j]; + const uint64_t nb = tensor->nb[j]; + + fwrite(&ne, sizeof(uint64_t), 1, fout); + fwrite(&nb, sizeof(uint64_t), 1, fout); + } + + fwrite(tensor->name, sizeof(char), WSP_GGML_MAX_NAME, fout); + + // dump the data + // TODO: pad this to 32 byte boundary + { + const size_t size = wsp_ggml_nbytes(tensor); + + fwrite(tensor->data, sizeof(char), size, fout); + } + } + } + + // nodes + { + for (int i = 0; i < cgraph->n_nodes; ++i) { + const struct wsp_ggml_tensor * tensor = cgraph->nodes[i]; + + const uint32_t type = tensor->type; + const uint32_t op = tensor->op; + const uint32_t n_dims = tensor->n_dims; + + fwrite(&type, sizeof(uint32_t), 1, fout); + fwrite(&op, sizeof(uint32_t), 1, fout); + fwrite(&n_dims, sizeof(uint32_t), 1, fout); + + for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) { + const uint64_t ne = tensor->ne[j]; + const uint64_t nb = tensor->nb[j]; + + fwrite(&ne, sizeof(uint64_t), 1, fout); + fwrite(&nb, sizeof(uint64_t), 1, fout); + } + + fwrite(tensor->name, sizeof(char), WSP_GGML_MAX_NAME, fout); + + // output the op arguments + { + struct wsp_ggml_tensor * args[2 + WSP_GGML_MAX_OPT] = { NULL }; + + args[0] = tensor->src0; + args[1] = tensor->src1; + + for (int j = 0; j < WSP_GGML_MAX_OPT; ++j) { + args[2 + j] = tensor->opt[j]; + } + + for (int j = 0; j < 2 + WSP_GGML_MAX_OPT; ++j) { + if (args[j]) { + int32_t idx = -1; + + // check if leaf + { + for (int k = 0; k < cgraph->n_leafs; ++k) { + if (args[j] == cgraph->leafs[k]) { + idx = k; + break; + } + } + } + + // check if node + if (idx == -1) { + for (int k = 0; k < cgraph->n_nodes; ++k) { + if (args[j] == cgraph->nodes[k]) { + idx = WSP_GGML_MAX_NODES + k; + break; + } + } + } + + if (idx == -1) { + fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i); + return; + } + + fwrite(&idx, sizeof(int32_t), 1, fout); + } else { + const int32_t nul = -1; + + fwrite(&nul, sizeof(int32_t), 1, fout); + } + } + } + } + } + + fclose(fout); + } +} + +struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval) { + assert(*ctx_data == NULL); + assert(*ctx_eval == NULL); + + struct wsp_ggml_cgraph result = { 0 }; + + struct wsp_ggml_tensor * data = NULL; + + // read file into data + { + FILE * fin = fopen(fname, "rb"); + if (!fin) { + fprintf(stderr, "%s: failed to open %s\n", __func__, fname); + return result; + } + + size_t fsize = 0; + + fseek(fin, 0, SEEK_END); + fsize = ftell(fin); + fseek(fin, 0, SEEK_SET); + + // create the data context + { + const size_t overhead = 1*wsp_ggml_tensor_overhead(); + + struct wsp_ggml_init_params params = { + .mem_size = fsize + overhead, + .mem_buffer = NULL, + .no_alloc = false, + }; + + *ctx_data = wsp_ggml_init(params); + + if (!*ctx_data) { + fprintf(stderr, "%s: failed to create ggml context\n", __func__); + fclose(fin); + return result; + } + } + + data = wsp_ggml_new_tensor_1d(*ctx_data, WSP_GGML_TYPE_I8, fsize); + + { + const size_t ret = fread(data->data, sizeof(char), fsize, fin); + if (ret != fsize) { + fprintf(stderr, "%s: failed to read %s\n", __func__, fname); + fclose(fin); + return result; + } + } + + fclose(fin); + } + + // populate result + { + char * ptr = (char *) data->data; + + const uint32_t magic = *(const uint32_t *) ptr; ptr += sizeof(magic); + + if (magic != 
WSP_GGML_FILE_MAGIC) { + fprintf(stderr, "%s: invalid magic number, got %08x\n", __func__, magic); + return result; + } + + const uint32_t version = *(const uint32_t *) ptr; ptr += sizeof(version); + + if (version != WSP_GGML_FILE_VERSION) { + fprintf(stderr, "%s: invalid version number\n", __func__); + return result; + } + + const uint32_t n_leafs = *(const uint32_t *) ptr; ptr += sizeof(n_leafs); + const uint32_t n_nodes = *(const uint32_t *) ptr; ptr += sizeof(n_nodes); + const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval); + + result.n_leafs = n_leafs; + result.n_nodes = n_nodes; + + // create the data context + { + const size_t overhead = (n_leafs + n_nodes)*wsp_ggml_tensor_overhead(); + + struct wsp_ggml_init_params params = { + .mem_size = size_eval + overhead, + .mem_buffer = NULL, + .no_alloc = true, + }; + + *ctx_eval = wsp_ggml_init(params); + + if (!*ctx_eval) { + fprintf(stderr, "%s: failed to create ggml context\n", __func__); + return result; + } + } + + // leafs + { + uint32_t type; + uint32_t op; + uint32_t n_dims; + + for (uint32_t i = 0; i < n_leafs; ++i) { + type = *(const uint32_t *) ptr; ptr += sizeof(type); + op = *(const uint32_t *) ptr; ptr += sizeof(op); + n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims); + + int64_t ne[WSP_GGML_MAX_DIMS]; + size_t nb[WSP_GGML_MAX_DIMS]; + + for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) { + uint64_t ne_cur; + uint64_t nb_cur; + + ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur); + nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur); + + ne[j] = ne_cur; + nb[j] = nb_cur; + } + + struct wsp_ggml_tensor * tensor = wsp_ggml_new_tensor(*ctx_eval, (enum wsp_ggml_type) type, n_dims, ne); + + tensor->op = (enum wsp_ggml_op) op; + + memcpy(tensor->name, ptr, WSP_GGML_MAX_NAME); ptr += WSP_GGML_MAX_NAME; + + tensor->data = (void *) ptr; + + for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) { + tensor->nb[j] = nb[j]; + } + + result.leafs[i] = tensor; + + ptr += wsp_ggml_nbytes(tensor); + + fprintf(stderr, "%s: loaded leaf %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, wsp_ggml_nbytes(tensor)); + } + } + + wsp_ggml_set_no_alloc(*ctx_eval, false); + + // nodes + { + uint32_t type; + uint32_t op; + uint32_t n_dims; + + for (uint32_t i = 0; i < n_nodes; ++i) { + type = *(const uint32_t *) ptr; ptr += sizeof(type); + op = *(const uint32_t *) ptr; ptr += sizeof(op); + n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims); + + enum wsp_ggml_op eop = (enum wsp_ggml_op) op; + + int64_t ne[WSP_GGML_MAX_DIMS]; + size_t nb[WSP_GGML_MAX_DIMS]; + + for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) { + uint64_t ne_cur; + uint64_t nb_cur; + + ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur); + nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur); + + ne[j] = ne_cur; + nb[j] = nb_cur; + } + + const char * ptr_name = ptr; ptr += WSP_GGML_MAX_NAME; + + const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + WSP_GGML_MAX_OPT)*sizeof(int32_t); + + struct wsp_ggml_tensor * args[2 + WSP_GGML_MAX_OPT] = { NULL }; + + // parse args + for (int j = 0; j < 2 + WSP_GGML_MAX_OPT; ++j) { + const int32_t arg_idx = ptr_arg_idx[j]; + + if (arg_idx == -1) { + continue; + } + + if (arg_idx < WSP_GGML_MAX_NODES) { + args[j] = result.leafs[arg_idx]; + } else { + args[j] = result.nodes[arg_idx - WSP_GGML_MAX_NODES]; + } + } + + // create the tensor + // "view" operations are handled differently + // TODO: handle inplace ops - currently a copy is always made + + struct wsp_ggml_tensor * tensor = NULL; + + 
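Aside: wsp_ggml_graph_export and wsp_ggml_graph_import above agree on a small binary layout: a fixed header (magic, version, leaf count, node count, eval size) followed by one record per leaf and per node. A rough standalone reader for just the header, to spell the field order out (the file path comes from argv; this is a sketch, not the importer):

    #include <inttypes.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(int argc, char ** argv) {
        if (argc < 2) { fprintf(stderr, "usage: %s graph.bin\n", argv[0]); return 1; }

        FILE * fin = fopen(argv[1], "rb");
        if (!fin) { perror("fopen"); return 1; }

        uint32_t magic, version, n_leafs, n_nodes;
        uint64_t size_eval;

        size_t n = 0;
        n += fread(&magic,     sizeof(magic),     1, fin);
        n += fread(&version,   sizeof(version),   1, fin);
        n += fread(&n_leafs,   sizeof(n_leafs),   1, fin);
        n += fread(&n_nodes,   sizeof(n_nodes),   1, fin);
        n += fread(&size_eval, sizeof(size_eval), 1, fin);
        fclose(fin);

        if (n != 5) { fprintf(stderr, "short read\n"); return 1; }

        printf("magic = %08" PRIx32 ", version = %" PRIu32 ", leafs = %" PRIu32
               ", nodes = %" PRIu32 ", eval = %" PRIu64 " bytes\n",
               magic, version, n_leafs, n_nodes, size_eval);
        return 0;
    }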
switch (eop) { + // TODO: implement other view ops + case WSP_GGML_OP_RESHAPE: + { + tensor = wsp_ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]); + } break; + case WSP_GGML_OP_VIEW: + { + tensor = wsp_ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); + + uint64_t offs; + memcpy(&offs, args[2]->data, sizeof(offs)); + + tensor->data = ((char *) tensor->data) + offs; + } break; + case WSP_GGML_OP_TRANSPOSE: + { + tensor = wsp_ggml_transpose(*ctx_eval, args[0]); + } break; + case WSP_GGML_OP_PERMUTE: + { + tensor = wsp_ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); + } break; + default: + { + tensor = wsp_ggml_new_tensor(*ctx_eval, (enum wsp_ggml_type) type, n_dims, ne); + + tensor->op = eop; + } break; + } + + memcpy(tensor->name, ptr_name, WSP_GGML_MAX_NAME); + + for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) { + tensor->nb[j] = nb[j]; + } + + tensor->src0 = args[0]; + tensor->src1 = args[1]; + + for (int j = 0; j < WSP_GGML_MAX_OPT; ++j) { + tensor->opt[j] = args[2 + j]; + } + + result.nodes[i] = tensor; + + fprintf(stderr, "%s: loaded node %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, wsp_ggml_nbytes(tensor)); + } + } + } + + return result; +} + +void wsp_ggml_graph_print(const struct wsp_ggml_cgraph * cgraph) { + int64_t perf_total_per_op_us[WSP_GGML_OP_COUNT] = {0}; + + WSP_GGML_PRINT("=== GRAPH ===\n"); + + WSP_GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads); + WSP_GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size); + + WSP_GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + struct wsp_ggml_tensor * node = cgraph->nodes[i]; + + perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us); + + WSP_GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", + i, + node->ne[0], node->ne[1], node->ne[2], + WSP_GGML_OP_NAME[node->op], node->is_param ? "x" : node->grad ? 
"g" : " ", node->perf_runs, + (double) node->perf_cycles / (double) wsp_ggml_cycles_per_ms(), + (double) node->perf_cycles / (double) wsp_ggml_cycles_per_ms() / (double) node->perf_runs, + (double) node->perf_time_us / 1000.0, + (double) node->perf_time_us / 1000.0 / node->perf_runs); + } + + WSP_GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs); + for (int i = 0; i < cgraph->n_leafs; i++) { + struct wsp_ggml_tensor * node = cgraph->leafs[i]; + + WSP_GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n", + i, + node->ne[0], node->ne[1], + WSP_GGML_OP_NAME[node->op]); + } + + for (int i = 0; i < WSP_GGML_OP_COUNT; i++) { + if (perf_total_per_op_us[i] == 0) { + continue; + } + + WSP_GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", WSP_GGML_OP_NAME[i], (double) perf_total_per_op_us[i] / 1000.0); + } + + WSP_GGML_PRINT("========================================\n"); +} + +// check if node is part of the graph +static bool wsp_ggml_graph_find(const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node) { + if (cgraph == NULL) { + return true; + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i] == node) { + return true; + } + } + + return false; +} + +static struct wsp_ggml_tensor * wsp_ggml_graph_get_parent(const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct wsp_ggml_tensor * parent = cgraph->nodes[i]; + + if (parent->grad == node) { + return parent; + } + } + + return NULL; +} + +static void wsp_ggml_graph_dump_dot_node_edge(FILE * fp, const struct wsp_ggml_cgraph * gb, struct wsp_ggml_tensor * node, struct wsp_ggml_tensor * parent, const char * label) { + struct wsp_ggml_tensor * gparent = wsp_ggml_graph_get_parent(gb, node); + struct wsp_ggml_tensor * gparent0 = wsp_ggml_graph_get_parent(gb, parent); + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n", + gparent0 ? (void *) gparent0 : (void *) parent, + gparent0 ? "g" : "x", + gparent ? (void *) gparent : (void *) node, + gparent ? "g" : "x", + gparent ? "empty" : "vee", + gparent ? 
"dashed" : "solid", + label); +} + +static void wsp_ggml_graph_dump_dot_leaf_edge(FILE * fp, struct wsp_ggml_tensor * node, struct wsp_ggml_tensor * parent, const char * label) { + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n", + (void *) parent, "x", + (void *) node, "x", + label); +} + +void wsp_ggml_graph_dump_dot(const struct wsp_ggml_cgraph * gb, const struct wsp_ggml_cgraph * gf, const char * filename) { + char color[16]; + + FILE * fp = fopen(filename, "w"); + WSP_GGML_ASSERT(fp); + + fprintf(fp, "digraph G {\n"); + fprintf(fp, " newrank = true;\n"); + fprintf(fp, " rankdir = LR;\n"); + + for (int i = 0; i < gb->n_nodes; i++) { + struct wsp_ggml_tensor * node = gb->nodes[i]; + + if (wsp_ggml_graph_get_parent(gb, node) != NULL) { + continue; + } + + if (node->is_param) { + snprintf(color, sizeof(color), "yellow"); + } else if (node->grad) { + if (wsp_ggml_graph_find(gf, node)) { + snprintf(color, sizeof(color), "green"); + } else { + snprintf(color, sizeof(color), "lightblue"); + } + } else { + snprintf(color, sizeof(color), "white"); + } + + fprintf(fp, " \"%p\" [ " + "style = filled; fillcolor = %s; shape = record; " + "label=\"", + (void *) node, color); + + if (strlen(node->name) > 0) { + fprintf(fp, "%s (%s)|", node->name, wsp_ggml_type_name(node->type)); + } else { + fprintf(fp, "(%s)|", wsp_ggml_type_name(node->type)); + } + + if (node->n_dims == 2) { + fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], WSP_GGML_OP_SYMBOL[node->op]); + } else { + fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], WSP_GGML_OP_SYMBOL[node->op]); + } + + if (node->grad) { + fprintf(fp, " | %s\"; ]\n", WSP_GGML_OP_SYMBOL[node->grad->op]); + } else { + fprintf(fp, "\"; ]\n"); + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct wsp_ggml_tensor * node = gb->leafs[i]; + + snprintf(color, sizeof(color), "pink"); + + fprintf(fp, " \"%p\" [ " + "style = filled; fillcolor = %s; shape = record; " + "label=\"", + (void *) node, color); + + if (strlen(node->name) > 0) { + fprintf(fp, "%s (%s)|", node->name, wsp_ggml_type_name(node->type)); + } else { + fprintf(fp, "(%s)|", wsp_ggml_type_name(node->type)); + } + + fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); + if (wsp_ggml_nelements(node) < 5) { + fprintf(fp, " | ("); + for (int j = 0; j < wsp_ggml_nelements(node); j++) { + if (node->type == WSP_GGML_TYPE_I8 || node->type == WSP_GGML_TYPE_I16 || node->type == WSP_GGML_TYPE_I32) { + fprintf(fp, "%d", wsp_ggml_get_i32_1d(node, j)); + } + else if (node->type == WSP_GGML_TYPE_F32 || node->type == WSP_GGML_TYPE_F16) { + fprintf(fp, "%.1e", (double)wsp_ggml_get_f32_1d(node, j)); + } + else { + fprintf(fp, "#"); + } + if (j < wsp_ggml_nelements(node) - 1) { + fprintf(fp, ", "); + } + } + fprintf(fp, ")"); + } + fprintf(fp, "\"; ]\n"); + } + + for (int i = 0; i < gb->n_nodes; i++) { + struct wsp_ggml_tensor * node = gb->nodes[i]; + + if (node->src0) { + wsp_ggml_graph_dump_dot_node_edge(fp, gb, node, node->src0, "x"); + } + + if (node->src1) { + wsp_ggml_graph_dump_dot_node_edge(fp, gb, node, node->src1, "y"); + } + + for (int j = 0; j < WSP_GGML_MAX_OPT; j++) { + if (node->opt[j]) { + char label[16]; + snprintf(label, sizeof(label), "opt %d", j); + wsp_ggml_graph_dump_dot_node_edge(fp, gb, node, node->opt[j], label); + } + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct wsp_ggml_tensor * node = gb->leafs[i]; + + if (node->src0) { + 
wsp_ggml_graph_dump_dot_leaf_edge(fp, node, node->src0, "x"); + } + + if (node->src1) { + wsp_ggml_graph_dump_dot_leaf_edge(fp, node, node->src1, "y"); + } + + for (int j = 0; j < WSP_GGML_MAX_OPT; j++) { + if (node->opt[j]) { + char label[16]; + snprintf(label, sizeof(label), "opt %d", j); + wsp_ggml_graph_dump_dot_leaf_edge(fp, node, node->opt[j], label); + } + } + } + + fprintf(fp, "}\n"); + + fclose(fp); + + WSP_GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); +} + +//////////////////////////////////////////////////////////////////////////////// + +static void wsp_ggml_opt_set_params(int np, struct wsp_ggml_tensor * const ps[], const float * x) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = wsp_ggml_nelements(ps[p]) ; + // TODO: add function to set tensor from array + for (int64_t j = 0; j < ne; ++j) { + wsp_ggml_set_f32_1d(ps[p], j, x[i++]); + } + } +} + +static void wsp_ggml_opt_get_params(int np, struct wsp_ggml_tensor * const ps[], float * x) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = wsp_ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int64_t j = 0; j < ne; ++j) { + x[i++] = wsp_ggml_get_f32_1d(ps[p], j); + } + } +} + +static void wsp_ggml_opt_get_grad(int np, struct wsp_ggml_tensor * const ps[], float * g) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = wsp_ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int64_t j = 0; j < ne; ++j) { + g[i++] = wsp_ggml_get_f32_1d(ps[p]->grad, j); + } + } +} + +// +// ADAM +// +// ref: https://arxiv.org/pdf/1412.6980.pdf +// + +static enum wsp_ggml_opt_result wsp_ggml_opt_adam( + struct wsp_ggml_context * ctx, + struct wsp_ggml_opt_context * opt, + struct wsp_ggml_opt_params params, + struct wsp_ggml_tensor * f, + struct wsp_ggml_cgraph * gf, + struct wsp_ggml_cgraph * gb) { + WSP_GGML_ASSERT(wsp_ggml_is_scalar(f)); + + gf->n_threads = params.n_threads; + gb->n_threads = params.n_threads; + + // these will store the parameters we want to optimize + struct wsp_ggml_tensor * ps[WSP_GGML_MAX_PARAMS]; + + int np = 0; + int nx = 0; + for (int i = 0; i < gf->n_nodes; ++i) { + if (gf->nodes[i]->is_param) { + WSP_GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); + + WSP_GGML_ASSERT(np < WSP_GGML_MAX_PARAMS); + + ps[np++] = gf->nodes[i]; + nx += wsp_ggml_nelements(gf->nodes[i]); + } + } + + if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) { + int iter = opt->iter; + wsp_ggml_opt_init(opt->ctx, opt, params, nx); + opt->iter = iter; + } + + // constants + const float sched = params.adam.sched; + const float decay = params.adam.decay * sched; + const float alpha = params.adam.alpha * sched; + const float beta1 = params.adam.beta1; + const float beta2 = params.adam.beta2; + const float eps = params.adam.eps; + + float * x = opt->adam.x->data; // view of the parameters + float * g1 = opt->adam.g1->data; // gradient + float * g2 = opt->adam.g2->data; // gradient squared + float * m = opt->adam.m->data; // first moment + float * v = opt->adam.v->data; // second moment + float * mh = opt->adam.mh->data; // first moment hat + float * vh = opt->adam.vh->data; // second moment hat + + float * pf = params.past > 0 ? 
opt->adam.pf->data : NULL; // past function values + + // update view + wsp_ggml_opt_get_params(np, ps, x); + + // compute the function value + wsp_ggml_graph_reset (gf); + wsp_ggml_set_f32 (f->grad, 1.0f); + wsp_ggml_graph_compute(ctx, gb); + + opt->adam.fx_prev = wsp_ggml_get_f32_1d(f, 0); + opt->adam.fx_best = opt->adam.fx_prev; + if (pf) { + pf[opt->iter % params.past] = opt->adam.fx_prev; + } + + // initialize + if (opt->just_initialized) { + opt->adam.n_no_improvement = 0; + opt->just_initialized = false; + } + + float * fx_best = &opt->adam.fx_best; + float * fx_prev = &opt->adam.fx_prev; + int * n_no_improvement = &opt->adam.n_no_improvement; + + int iter0 = opt->iter; + + // run the optimizer + for (int t = 0; t < params.adam.n_iter; ++t) { + opt->iter = iter0 + t + 1; + WSP_GGML_PRINT_DEBUG ("=== iter %d ===\n", t); + + WSP_GGML_PRINT_DEBUG ("f = %10.6f\n", wsp_ggml_get_f32_1d(f, 0)); + WSP_GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", wsp_ggml_get_f32_1d(ps[0]->grad, 0)); + WSP_GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", wsp_ggml_get_f32_1d(ps[1]->grad, 0)); + + for (int i = 0; i < np; ++i) { + WSP_GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i, + wsp_ggml_get_f32_1d(ps[i], 0), wsp_ggml_get_f32_1d(ps[i]->grad, 0)); + } + + const int64_t t_start_wall = wsp_ggml_time_us(); + const int64_t t_start_cpu = wsp_ggml_cycles(); + UNUSED(t_start_wall); + UNUSED(t_start_cpu); + + { + // update the gradient + wsp_ggml_opt_get_grad(np, ps, g1); + + // m_t = beta1*m_t-1 + (1 - beta1)*g_t + wsp_ggml_vec_scale_f32(nx, m, beta1); + wsp_ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1); + + // g2 = g1^2 + wsp_ggml_vec_sqr_f32 (nx, g2, g1); + + // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2 + wsp_ggml_vec_scale_f32(nx, v, beta2); + wsp_ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2); + + // m^hat = m_t / (1 - beta1^t) + // v^hat = v_t / (1 - beta2^t) + // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1) + // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1 + // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps) + // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps) + // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay) + wsp_ggml_vec_cpy_f32 (nx, mh, m); + wsp_ggml_vec_cpy_f32 (nx, vh, v); + + wsp_ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter))); + wsp_ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter))); + + wsp_ggml_vec_sqrt_f32 (nx, vh, vh); + wsp_ggml_vec_acc1_f32 (nx, vh, eps); + + wsp_ggml_vec_div_f32 (nx, mh, mh, vh); + wsp_ggml_vec_scale_f32(nx, x, 1.0f - decay); + wsp_ggml_vec_sub_f32 (nx, x, x, mh); + + // update the parameters + wsp_ggml_opt_set_params(np, ps, x); + } + + wsp_ggml_graph_reset (gf); + wsp_ggml_set_f32 (f->grad, 1.0f); + wsp_ggml_graph_compute(ctx, gb); + + const float fx = wsp_ggml_get_f32_1d(f, 0); + + // check convergence + if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) { + WSP_GGML_PRINT_DEBUG("converged\n"); + + return WSP_GGML_OPT_OK; + } + + // delta-based convergence test + if (pf != NULL) { + // need at least params.past iterations to start checking for convergence + if (params.past <= iter0 + t) { + const float rate = (pf[(iter0 + t)%params.past] - fx)/fx; + + if (fabsf(rate) < params.delta) { + return WSP_GGML_OPT_OK; + } + } + + pf[(iter0 + t)%params.past] = fx; + } + + // check for improvement + if (params.max_no_improvement > 0) { + if (fx_best[0] > fx) { + fx_best[0] = fx; + n_no_improvement[0] = 0; + } else { + ++n_no_improvement[0]; 
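                // if f has not improved for max_no_improvement consecutive iterations,
                // the check below treats the run as converged and stops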
+ + if (n_no_improvement[0] >= params.max_no_improvement) { + return WSP_GGML_OPT_OK; + } + } + } + + fx_prev[0] = fx; + + { + const int64_t t_end_cpu = wsp_ggml_cycles(); + WSP_GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC); + UNUSED(t_end_cpu); + + const int64_t t_end_wall = wsp_ggml_time_us(); + WSP_GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6); + UNUSED(t_end_wall); + } + } + + return WSP_GGML_OPT_DID_NOT_CONVERGE; +} + +// +// L-BFGS +// +// the L-BFGS implementation below is based on the following implementation: +// +// https://github.com/chokkan/liblbfgs +// + +struct wsp_ggml_lbfgs_iteration_data { + float alpha; + float ys; + float * s; + float * y; +}; + +static enum wsp_ggml_opt_result linesearch_backtracking( + struct wsp_ggml_context * ctx, + const struct wsp_ggml_opt_params * params, + int nx, + float * x, + float * fx, + float * g, + float * d, + float * step, + const float * xp, + struct wsp_ggml_tensor * f, + struct wsp_ggml_cgraph * gf, + struct wsp_ggml_cgraph * gb, + const int np, + struct wsp_ggml_tensor * ps[]) { + int count = 0; + + float width = 0.0f; + float dg = 0.0f; + float finit = 0.0f; + float dginit = 0.0f; + float dgtest = 0.0f; + + const float dec = 0.5f; + const float inc = 2.1f; + + if (*step <= 0.f) { + return WSP_GGML_LINESEARCH_INVALID_PARAMETERS; + } + + // compute the initial gradient in the search direction + wsp_ggml_vec_dot_f32(nx, &dginit, g, d); + + // make sure that d points to a descent direction + if (0 < dginit) { + return WSP_GGML_LINESEARCH_FAIL; + } + + // initialize local variables + finit = *fx; + dgtest = params->lbfgs.ftol*dginit; + + while (true) { + wsp_ggml_vec_cpy_f32(nx, x, xp); + wsp_ggml_vec_mad_f32(nx, x, d, *step); + + // evaluate the function and gradient values + { + wsp_ggml_opt_set_params(np, ps, x); + + wsp_ggml_graph_reset (gf); + wsp_ggml_set_f32 (f->grad, 1.0f); + wsp_ggml_graph_compute(ctx, gb); + + wsp_ggml_opt_get_grad(np, ps, g); + + *fx = wsp_ggml_get_f32_1d(f, 0); + } + + ++count; + + if (*fx > finit + (*step)*dgtest) { + width = dec; + } else { + // Armijo condition is satisfied + if (params->lbfgs.linesearch == WSP_GGML_LINESEARCH_BACKTRACKING_ARMIJO) { + return count; + } + + wsp_ggml_vec_dot_f32(nx, &dg, g, d); + + // check the Wolfe condition + if (dg < params->lbfgs.wolfe * dginit) { + width = inc; + } else { + if(params->lbfgs.linesearch == WSP_GGML_LINESEARCH_BACKTRACKING_WOLFE) { + // regular Wolfe conditions + return count; + } + + if(dg > -params->lbfgs.wolfe*dginit) { + width = dec; + } else { + // strong Wolfe condition (WSP_GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) + return count; + } + return count; + } + } + + if (*step < params->lbfgs.min_step) { + return WSP_GGML_LINESEARCH_MINIMUM_STEP; + } + if (*step > params->lbfgs.max_step) { + return WSP_GGML_LINESEARCH_MAXIMUM_STEP; + } + if (params->lbfgs.max_linesearch <= count) { + return WSP_GGML_LINESEARCH_MAXIMUM_ITERATIONS; + } + + (*step) *= width; + } + + return WSP_GGML_LINESEARCH_FAIL; +} + +static enum wsp_ggml_opt_result wsp_ggml_opt_lbfgs( + struct wsp_ggml_context * ctx, + struct wsp_ggml_opt_context * opt, + struct wsp_ggml_opt_params params, + struct wsp_ggml_tensor * f, + struct wsp_ggml_cgraph * gf, + struct wsp_ggml_cgraph * gb) { + if (params.lbfgs.linesearch == WSP_GGML_LINESEARCH_BACKTRACKING_WOLFE || + params.lbfgs.linesearch == WSP_GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { + if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) 
{ + return WSP_GGML_OPT_INVALID_WOLFE; + } + } + + gf->n_threads = params.n_threads; + gb->n_threads = params.n_threads; + + const int m = params.lbfgs.m; + + // these will store the parameters we want to optimize + struct wsp_ggml_tensor * ps[WSP_GGML_MAX_PARAMS]; + + int np = 0; + int nx = 0; + for (int i = 0; i < gf->n_nodes; ++i) { + if (gf->nodes[i]->is_param) { + WSP_GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); + + WSP_GGML_ASSERT(np < WSP_GGML_MAX_PARAMS); + + ps[np++] = gf->nodes[i]; + nx += wsp_ggml_nelements(gf->nodes[i]); + } + } + + if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) { + int iter = opt->iter; + wsp_ggml_opt_init(ctx, opt, params, nx); + opt->iter = iter; + } + + float * x = opt->lbfgs.x->data; // current parameters + float * xp = opt->lbfgs.xp->data; // previous parameters + float * g = opt->lbfgs.g->data; // current gradient + float * gp = opt->lbfgs.gp->data; // previous gradient + float * d = opt->lbfgs.d->data; // search direction + + float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values + + float fx = 0.0f; // cost function value + float xnorm = 0.0f; // ||x|| + float gnorm = 0.0f; // ||g|| + + // initialize x from the graph nodes + wsp_ggml_opt_get_params(np, ps, x); + + // the L-BFGS memory + float * lm_alpha = opt->lbfgs.lmal->data; + float * lm_ys = opt->lbfgs.lmys->data; + float * lm_s = opt->lbfgs.lms->data; + float * lm_y = opt->lbfgs.lmy->data; + + // evaluate the function value and its gradient + { + wsp_ggml_opt_set_params(np, ps, x); + + wsp_ggml_graph_reset (gf); + wsp_ggml_set_f32 (f->grad, 1.0f); + wsp_ggml_graph_compute(ctx, gb); + + wsp_ggml_opt_get_grad(np, ps, g); + + fx = wsp_ggml_get_f32_1d(f, 0); + } + + // search direction = -gradient + wsp_ggml_vec_neg_f32(nx, d, g); + + // ||x||, ||g|| + wsp_ggml_vec_norm_f32(nx, &xnorm, x); + wsp_ggml_vec_norm_f32(nx, &gnorm, g); + + if (xnorm < 1.0f) { + xnorm = 1.0f; + } + + // already optimized + if (gnorm/xnorm <= params.lbfgs.eps) { + return WSP_GGML_OPT_OK; + } + + if (opt->just_initialized) { + if (pf) { + pf[0] = fx; + } + opt->lbfgs.fx_best = fx; + + // initial step + wsp_ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d); + opt->lbfgs.j = 0; + opt->lbfgs.k = 1; + opt->lbfgs.end = 0; + opt->lbfgs.n_no_improvement = 0; + opt->just_initialized = false; + } + + float * fx_best = &opt->lbfgs.fx_best; + float * step = &opt->lbfgs.step; + int * j = &opt->lbfgs.j; + int * k = &opt->lbfgs.k; + int * end = &opt->lbfgs.end; + int * n_no_improvement = &opt->lbfgs.n_no_improvement; + + int ls = 0; + int bound = 0; + + float ys = 0.0f; + float yy = 0.0f; + float beta = 0.0f; + + int it = 0; + + while (true) { + // store the current position and gradient vectors + wsp_ggml_vec_cpy_f32(nx, xp, x); + wsp_ggml_vec_cpy_f32(nx, gp, g); + + ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps); + + if (ls < 0) { + // linesearch failed - go back to the previous point and return + wsp_ggml_vec_cpy_f32(nx, x, xp); + wsp_ggml_vec_cpy_f32(nx, g, gp); + + return ls; + } + + wsp_ggml_vec_norm_f32(nx, &xnorm, x); + wsp_ggml_vec_norm_f32(nx, &gnorm, g); + + WSP_GGML_PRINT_DEBUG("f = %10.6f\n", wsp_ggml_get_f32_1d(f, 0)); + + if (xnorm < 1.0f) { + xnorm = 1.0f; + } + if (gnorm/xnorm <= params.lbfgs.eps) { + // converged + return WSP_GGML_OPT_OK; + } + + // delta-based convergence test + if (pf != NULL) { + // need at least params.past 
iterations to start checking for convergence + if (params.past <= k[0]) { + const float rate = (pf[k[0]%params.past] - fx)/fx; + + if (fabsf(rate) < params.delta) { + return WSP_GGML_OPT_OK; + } + } + + pf[k[0]%params.past] = fx; + } + + // check for improvement + if (params.max_no_improvement > 0) { + if (fx < fx_best[0]) { + fx_best[0] = fx; + n_no_improvement[0] = 0; + } else { + n_no_improvement[0]++; + + if (n_no_improvement[0] >= params.max_no_improvement) { + return WSP_GGML_OPT_OK; + } + } + } + + if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) { + // reached the maximum number of iterations + return WSP_GGML_OPT_DID_NOT_CONVERGE; + } + + // update vectors s and y: + // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}. + // y_{k+1} = g_{k+1} - g_{k}. + // + wsp_ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp); + wsp_ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp); + + // compute scalars ys and yy: + // ys = y^t \cdot s -> 1 / \rho. + // yy = y^t \cdot y. + // + wsp_ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]); + wsp_ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]); + + lm_ys[end[0]] = ys; + + // find new search direction + // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS + + bound = (m <= k[0]) ? m : k[0]; + k[0]++; + it++; + end[0] = (end[0] + 1)%m; + + // initialize search direction with -g + wsp_ggml_vec_neg_f32(nx, d, g); + + j[0] = end[0]; + for (int i = 0; i < bound; ++i) { + j[0] = (j[0] + m - 1) % m; + // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} + wsp_ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d); + lm_alpha[j[0]] /= lm_ys[j[0]]; + // q_{i} = q_{i+1} - \alpha_{i} y_{i} + wsp_ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]); + } + + wsp_ggml_vec_scale_f32(nx, d, ys/yy); + + for (int i = 0; i < bound; ++i) { + // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} + wsp_ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d); + beta /= lm_ys[j[0]]; + // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} + wsp_ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta); + j[0] = (j[0] + 1)%m; + } + + step[0] = 1.0; + } + + return WSP_GGML_OPT_DID_NOT_CONVERGE; +} + +struct wsp_ggml_opt_params wsp_ggml_opt_default_params(enum wsp_ggml_opt_type type) { + struct wsp_ggml_opt_params result; + + switch (type) { + case WSP_GGML_OPT_ADAM: + { + result = (struct wsp_ggml_opt_params) { + .type = WSP_GGML_OPT_ADAM, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, + + .max_no_improvement = 100, + + .print_forward_graph = true, + .print_backward_graph = true, + + .adam = { + .n_iter = 10000, + .sched = 1.000f, + .decay = 0.001f, + .alpha = 0.001f, + .beta1 = 0.9f, + .beta2 = 0.999f, + .eps = 1e-8f, + .eps_f = 1e-5f, + .eps_g = 1e-3f, + }, + }; + } break; + case WSP_GGML_OPT_LBFGS: + { + result = (struct wsp_ggml_opt_params) { + .type = WSP_GGML_OPT_LBFGS, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, + + .max_no_improvement = 0, + + .print_forward_graph = true, + .print_backward_graph = true, + + .lbfgs = { + .m = 6, + .n_iter = 100, + .max_linesearch = 20, + + .eps = 1e-5f, + .ftol = 1e-4f, + .wolfe = 0.9f, + .min_step = 1e-20f, + .max_step = 1e+20f, + + .linesearch = WSP_GGML_LINESEARCH_DEFAULT, + }, + }; + } break; + } + + return result; +} + +WSP_GGML_API void wsp_ggml_opt_init( + struct wsp_ggml_context * ctx, + struct wsp_ggml_opt_context * opt, + struct wsp_ggml_opt_params params, + int64_t nx) { + opt->ctx = ctx; + opt->params = params; + opt->iter = 0; + opt->nx = nx; + opt->just_initialized = true; + switch 
(opt->params.type) { + case WSP_GGML_OPT_ADAM: + { + opt->adam.x = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->adam.g1 = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->adam.g2 = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->adam.m = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->adam.v = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->adam.mh = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->adam.vh = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->adam.pf = params.past > 0 + ? wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, params.past) + : NULL; + wsp_ggml_set_zero(opt->adam.x); + wsp_ggml_set_zero(opt->adam.g1); + wsp_ggml_set_zero(opt->adam.g2); + wsp_ggml_set_zero(opt->adam.m); + wsp_ggml_set_zero(opt->adam.v); + wsp_ggml_set_zero(opt->adam.mh); + wsp_ggml_set_zero(opt->adam.vh); + if (opt->adam.pf) { + wsp_ggml_set_zero(opt->adam.pf); + } + } break; + case WSP_GGML_OPT_LBFGS: + { + opt->lbfgs.x = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->lbfgs.xp = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->lbfgs.g = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->lbfgs.gp = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->lbfgs.d = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->lbfgs.pf = params.past > 0 + ? wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, params.past) + : NULL; + opt->lbfgs.lmal = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lmys = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lms = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, nx, params.lbfgs.m); + opt->lbfgs.lmy = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, nx, params.lbfgs.m); + wsp_ggml_set_zero(opt->lbfgs.x); + wsp_ggml_set_zero(opt->lbfgs.xp); + wsp_ggml_set_zero(opt->lbfgs.g); + wsp_ggml_set_zero(opt->lbfgs.gp); + wsp_ggml_set_zero(opt->lbfgs.d); + if (opt->lbfgs.pf) { + wsp_ggml_set_zero(opt->lbfgs.pf); + } + wsp_ggml_set_zero(opt->lbfgs.lmal); + wsp_ggml_set_zero(opt->lbfgs.lmys); + wsp_ggml_set_zero(opt->lbfgs.lms); + wsp_ggml_set_zero(opt->lbfgs.lmy); + } break; + } +} + +enum wsp_ggml_opt_result wsp_ggml_opt( + struct wsp_ggml_context * ctx, + struct wsp_ggml_opt_params params, + struct wsp_ggml_tensor * f) { + bool free_ctx = false; + if (ctx == NULL) { + struct wsp_ggml_init_params params_ctx = { + .mem_size = 16*1024*1024, + .mem_buffer = NULL, + .no_alloc = false, + }; + + ctx = wsp_ggml_init(params_ctx); + if (ctx == NULL) { + return WSP_GGML_OPT_NO_CONTEXT; + } + + free_ctx = true; + } + + enum wsp_ggml_opt_result result = WSP_GGML_OPT_OK; + + struct wsp_ggml_opt_context * opt = (struct wsp_ggml_opt_context *) alloca(sizeof(struct wsp_ggml_opt_context)); + + wsp_ggml_opt_init(ctx, opt, params, 0); + result = wsp_ggml_opt_resume(ctx, opt, f); + + if (free_ctx) { + wsp_ggml_free(ctx); + } + + return result; +} + +enum wsp_ggml_opt_result wsp_ggml_opt_resume( + struct wsp_ggml_context * ctx, + struct wsp_ggml_opt_context * opt, + struct wsp_ggml_tensor * f) { + + // build forward + backward compute graphs + struct wsp_ggml_tensor * gfbuf = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, sizeof(struct wsp_ggml_cgraph) / WSP_GGML_TYPE_SIZE[WSP_GGML_TYPE_I32]+ (sizeof(struct wsp_ggml_cgraph) % WSP_GGML_TYPE_SIZE[WSP_GGML_TYPE_I32] ? 
1 : 0)); + struct wsp_ggml_tensor * gbbuf = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, sizeof(struct wsp_ggml_cgraph) / WSP_GGML_TYPE_SIZE[WSP_GGML_TYPE_I32]+ (sizeof(struct wsp_ggml_cgraph) % WSP_GGML_TYPE_SIZE[WSP_GGML_TYPE_I32] ? 1 : 0)); + + struct wsp_ggml_cgraph * gf = (struct wsp_ggml_cgraph *) gfbuf->data; + struct wsp_ggml_cgraph * gb = (struct wsp_ggml_cgraph *) gbbuf->data; + + *gf = wsp_ggml_build_forward (f); + *gb = wsp_ggml_build_backward(ctx, gf, true); + + return wsp_ggml_opt_resume_g(ctx, opt, f, gf, gb); +} + +enum wsp_ggml_opt_result wsp_ggml_opt_resume_g( + struct wsp_ggml_context * ctx, + struct wsp_ggml_opt_context * opt, + struct wsp_ggml_tensor * f, + struct wsp_ggml_cgraph * gf, + struct wsp_ggml_cgraph * gb) { + + // build forward + backward compute graphs + enum wsp_ggml_opt_result result = WSP_GGML_OPT_OK; + + switch (opt->params.type) { + case WSP_GGML_OPT_ADAM: + { + result = wsp_ggml_opt_adam(ctx, opt, opt->params, f, gf, gb); + } break; + case WSP_GGML_OPT_LBFGS: + { + result = wsp_ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb); + } break; + } + + if (opt->params.print_forward_graph) { + wsp_ggml_graph_print (gf); + wsp_ggml_graph_dump_dot(gf, NULL, "opt-forward.dot"); + } + + if (opt->params.print_backward_graph) { + wsp_ggml_graph_print (gb); + wsp_ggml_graph_dump_dot(gb, gf, "opt-backward.dot"); + } + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK4_0 == 0); + const int nb = k / QK4_0; + + for (int b = 0; b < n; b += k) { + block_q4_0 * restrict y = (block_q4_0 *) dst + b/QK4_0; + + quantize_row_q4_0_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK4_0; j += 2) { + const uint8_t vi0 = y[i].qs[j/2] & 0x0F; + const uint8_t vi1 = y[i].qs[j/2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK4_0*sizeof(block_q4_0)); +} + +size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK4_1 == 0); + const int nb = k / QK4_1; + + for (int b = 0; b < n; b += k) { + block_q4_1 * restrict y = (block_q4_1 *) dst + b/QK4_1; + + quantize_row_q4_1_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK4_1; j += 2) { + const uint8_t vi0 = y[i].qs[j/2] & 0x0F; + const uint8_t vi1 = y[i].qs[j/2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK4_1*sizeof(block_q4_1)); +} + +size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + for (int b = 0; b < n; b += k) { + block_q5_0 * restrict y = (block_q5_0 *)dst + b/QK5_0; + + quantize_row_q5_0_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + + for (int j = 0; j < QK5_0; j += 2) { + const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[j/2] >> 4) | vh1) / 2; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_0*sizeof(block_q5_0)); +} + +size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_1 == 0); + const int nb = k / QK5_1; + + for (int b = 0; b < n; b += k) { + block_q5_1 * restrict y = (block_q5_1 *)dst + 
b/QK5_1; + + quantize_row_q5_1_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + + for (int j = 0; j < QK5_1; j += 2) { + const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[j/2] >> 4) | vh1) / 2; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_1*sizeof(block_q5_1)); +} + +size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + for (int b = 0; b < n; b += k) { + block_q8_0 * restrict y = (block_q8_0 *)dst + b/QK8_0; + + quantize_row_q8_0_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK8_0; ++j) { + const int8_t vi = y[i].qs[j]; + + hist[vi/16 + 8]++; + } + } + } + + return (n/QK8_0*sizeof(block_q8_0)); +} + +size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) { + size_t result = 0; + switch (type) { + case WSP_GGML_TYPE_Q4_0: + { + WSP_GGML_ASSERT(start % QK4_0 == 0); + block_q4_0 * block = (block_q4_0*)dst + start / QK4_0; + result = wsp_ggml_quantize_q4_0(src + start, block, n, n, hist); + } break; + case WSP_GGML_TYPE_Q4_1: + { + WSP_GGML_ASSERT(start % QK4_1 == 0); + block_q4_1 * block = (block_q4_1*)dst + start / QK4_1; + result = wsp_ggml_quantize_q4_1(src + start, block, n, n, hist); + } break; + case WSP_GGML_TYPE_Q5_0: + { + WSP_GGML_ASSERT(start % QK5_0 == 0); + block_q5_0 * block = (block_q5_0*)dst + start / QK5_0; + result = wsp_ggml_quantize_q5_0(src + start, block, n, n, hist); + } break; + case WSP_GGML_TYPE_Q5_1: + { + WSP_GGML_ASSERT(start % QK5_1 == 0); + block_q5_1 * block = (block_q5_1*)dst + start / QK5_1; + result = wsp_ggml_quantize_q5_1(src + start, block, n, n, hist); + } break; + case WSP_GGML_TYPE_Q8_0: + { + WSP_GGML_ASSERT(start % QK8_0 == 0); + block_q8_0 * block = (block_q8_0*)dst + start / QK8_0; + result = wsp_ggml_quantize_q8_0(src + start, block, n, n, hist); + } break; +#ifdef WSP_GGML_USE_K_QUANTS + case WSP_GGML_TYPE_Q2_K: + { + WSP_GGML_ASSERT(start % QK_K == 0); + block_q2_K * block = (block_q2_K*)dst + start / QK_K; + result = wsp_ggml_quantize_q2_K(src + start, block, n, n, hist); + } break; + case WSP_GGML_TYPE_Q3_K: + { + WSP_GGML_ASSERT(start % QK_K == 0); + block_q3_K * block = (block_q3_K*)dst + start / QK_K; + result = wsp_ggml_quantize_q3_K(src + start, block, n, n, hist); + } break; + case WSP_GGML_TYPE_Q4_K: + { + WSP_GGML_ASSERT(start % QK_K == 0); + block_q4_K * block = (block_q4_K*)dst + start / QK_K; + result = wsp_ggml_quantize_q4_K(src + start, block, n, n, hist); + } break; + case WSP_GGML_TYPE_Q5_K: + { + WSP_GGML_ASSERT(start % QK_K == 0); + block_q5_K * block = (block_q5_K*)dst + start / QK_K; + result = wsp_ggml_quantize_q5_K(src + start, block, n, n, hist); + } break; + case WSP_GGML_TYPE_Q6_K: + { + WSP_GGML_ASSERT(start % QK_K == 0); + block_q6_K * block = (block_q6_K*)dst + start / QK_K; + result = wsp_ggml_quantize_q6_K(src + start, block, n, n, hist); + } break; +#endif + case WSP_GGML_TYPE_F16: + { + int elemsize = sizeof(wsp_ggml_fp16_t); + wsp_ggml_fp32_to_fp16_row(src + start, (wsp_ggml_fp16_t *)dst + start, n); + result = n * elemsize; + } break; + case WSP_GGML_TYPE_F32: + { + int elemsize = sizeof(float); + result = n * elemsize; + memcpy((uint8_t *)dst + 
start * elemsize, src + start, result); + } break; + default: + assert(false); + } + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +int wsp_ggml_cpu_has_avx(void) { +#if defined(__AVX__) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_avx2(void) { +#if defined(__AVX2__) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_avx512(void) { +#if defined(__AVX512F__) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_avx512_vbmi(void) { +#if defined(__AVX512VBMI__) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_avx512_vnni(void) { +#if defined(__AVX512VNNI__) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_fma(void) { +#if defined(__FMA__) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_neon(void) { +#if defined(__ARM_NEON) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_arm_fma(void) { +#if defined(__ARM_FEATURE_FMA) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_f16c(void) { +#if defined(__F16C__) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_fp16_va(void) { +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_wasm_simd(void) { +#if defined(__wasm_simd128__) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_blas(void) { +#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) || defined(WSP_GGML_USE_CUBLAS) || defined(WSP_GGML_USE_CLBLAST) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_cublas(void) { +#if defined(WSP_GGML_USE_CUBLAS) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_clblast(void) { +#if defined(WSP_GGML_USE_CLBLAST) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_gpublas(void) { + return wsp_ggml_cpu_has_cublas() || wsp_ggml_cpu_has_clblast(); +} + +int wsp_ggml_cpu_has_sse3(void) { +#if defined(__SSE3__) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_ssse3(void) { +#if defined(__SSSE3__) + return 1; +#else + return 0; +#endif +} + +int wsp_ggml_cpu_has_vsx(void) { +#if defined(__POWER9_VECTOR__) + return 1; +#else + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/ggml.h b/cpp/ggml.h new file mode 100644 index 0000000..2d59f71 --- /dev/null +++ b/cpp/ggml.h @@ -0,0 +1,1541 @@ +#pragma once + +// +// GGML Tensor Library +// +// This documentation is still a work in progress. +// If you wish some specific topics to be covered, feel free to drop a comment: +// +// https://github.com/ggerganov/whisper.cpp/issues/40 +// +// ## Overview +// +// This library implements: +// +// - a set of tensor operations +// - automatic differentiation +// - basic optimization algorithms +// +// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes, +// but is not limited to, the following: +// +// - linear regression +// - support vector machines +// - neural networks +// +// The library allows the user to define a certain function using the available tensor operations. This function +// definition is represented internally via a computation graph. Each tensor operation in the function definition +// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the +// function's value and/or its gradient with respect to the input variables. 
Optionally, the function can be optimized +// using one of the available optimization algorithms. +// +// For example, here we define the function: f(x) = a*x^2 + b +// +// { +// struct wsp_ggml_init_params params = { +// .mem_size = 16*1024*1024, +// .mem_buffer = NULL, +// }; +// +// // memory allocation happens here +// struct wsp_ggml_context * ctx = wsp_ggml_init(params); +// +// struct wsp_ggml_tensor * x = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 1); +// +// wsp_ggml_set_param(ctx, x); // x is an input variable +// +// struct wsp_ggml_tensor * a = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 1); +// struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 1); +// struct wsp_ggml_tensor * x2 = wsp_ggml_mul(ctx, x, x); +// struct wsp_ggml_tensor * f = wsp_ggml_add(ctx, wsp_ggml_mul(ctx, a, x2), b); +// +// ... +// } +// +// Notice that the function definition above does not involve any actual computation. The computation is performed only +// when the user explicitly requests it. For example, to compute the function's value at x = 2.0: +// +// { +// ... +// +// struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(f); +// +// // set the input variable and parameter values +// wsp_ggml_set_f32(x, 2.0f); +// wsp_ggml_set_f32(a, 3.0f); +// wsp_ggml_set_f32(b, 4.0f); +// +// wsp_ggml_graph_compute(ctx0, &gf); +// +// printf("f = %f\n", wsp_ggml_get_f32_1d(f, 0)); +// +// ... +// } +// +// The actual computation is performed in the wsp_ggml_graph_compute() function. +// +// The wsp_ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the +// wsp_ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know +// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory +// and after defining the computation graph, call the wsp_ggml_used_mem() function to find out how much memory was +// actually needed. +// +// The wsp_ggml_set_param() function marks a tensor as an input variable. This is used by the automatic +// differentiation and optimization algorithms. +// +// The described approach allows to define the function graph once and then compute its forward or backward graphs +// multiple times. All computations will use the same memory buffer allocated in the wsp_ggml_init() function. This way +// the user can avoid the memory allocation overhead at runtime. +// +// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class +// citizens, but in theory the library can be extended to support FP8 and integer data types. +// +// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary +// and binary operations. Most of the available operations fall into one of these two categories. With time, it became +// clear that the library needs to support more complex operations. The way to support these operations is not clear +// yet, but a few examples are demonstrated in the following operations: +// +// - wsp_ggml_permute() +// - wsp_ggml_conv_1d_1s() +// - wsp_ggml_conv_1d_2s() +// +// For each tensor operator, the library implements a forward and backward computation function. The forward function +// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the +// input tensors given the adjoint of the output tensor. 
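//
// As a rough sketch (reusing the f(x) = a*x^2 + b example above and the same pattern the
// optimizers in ggml.c follow), the gradient df/dx at x = 2.0 could be obtained along
// these lines:
//
//   {
//       ...
//
//       struct wsp_ggml_cgraph gf = wsp_ggml_build_forward (f);
//       struct wsp_ggml_cgraph gb = wsp_ggml_build_backward(ctx, &gf, false);
//
//       wsp_ggml_set_f32(x, 2.0f);
//       wsp_ggml_set_f32(a, 3.0f);
//       wsp_ggml_set_f32(b, 4.0f);
//
//       wsp_ggml_graph_reset  (&gf);
//       wsp_ggml_set_f32      (f->grad, 1.0f);   // seed df/df = 1
//       wsp_ggml_graph_compute(ctx, &gb);
//
//       // x was marked with wsp_ggml_set_param(), so it carries a gradient tensor
//       printf("df/dx = %f\n", wsp_ggml_get_f32_1d(x->grad, 0));  // 2*a*x = 12.0
//   }
//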
For a detailed explanation of what this means, take a +// calculus class, or watch the following video: +// +// What is Automatic Differentiation? +// https://www.youtube.com/watch?v=wG_nF1awSSY +// +// +// ## Tensor data (struct wsp_ggml_tensor) +// +// The tensors are stored in memory via the wsp_ggml_tensor struct. The structure provides information about the size of +// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains +// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example: +// +// { +// struct wsp_ggml_tensor * c = wsp_ggml_add(ctx, a, b); +// +// assert(c->src[0] == a); +// assert(c->src[1] == b); +// } +// +// The multi-dimensional tensors are stored in row-major order. The wsp_ggml_tensor struct contains fields for the +// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows +// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and +// permutation. All tensor operations have to take the stride into account and not assume that the tensor is +// contiguous in memory. +// +// The data of the tensor is accessed via the "data" pointer. For example: +// +// { +// struct wsp_ggml_tensor * a = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 2, 3); +// +// // a[1, 2] = 1.0f; +// *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f; +// +// // a[2, 0] = 2.0f; +// *(float *) ((char *) a->data + 0*a->nb[1] + 2*a->nb[0]) = 2.0f; +// +// ... +// } +// +// Alternatively, there are helper functions, such as wsp_ggml_get_f32_1d() and wsp_ggml_set_f32_1d() that can be used. +// +// ## The matrix multiplication operator (wsp_ggml_mul_mat) +// +// TODO +// +// +// ## Multi-threading +// +// TODO +// +// +// ## Overview of ggml.c +// +// TODO +// +// +// ## SIMD optimizations +// +// TODO +// +// +// ## Debugging ggml +// +// TODO +// +// + +#ifdef WSP_GGML_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef WSP_GGML_BUILD +# define WSP_GGML_API __declspec(dllexport) +# else +# define WSP_GGML_API __declspec(dllimport) +# endif +# else +# define WSP_GGML_API __attribute__ ((visibility ("default"))) +# endif +#else +# define WSP_GGML_API +#endif + +#include +#include +#include + +#define WSP_GGML_FILE_MAGIC 0x67676d6c // "ggml" +#define WSP_GGML_FILE_VERSION 1 + +#define WSP_GGML_QNT_VERSION 2 // bump this on quantization format changes +#define WSP_GGML_QNT_VERSION_FACTOR 1000 // do not change this + +#define WSP_GGML_MAX_DIMS 4 +#define WSP_GGML_MAX_NODES 4096 +#define WSP_GGML_MAX_PARAMS 256 +#define WSP_GGML_MAX_CONTEXTS 64 +#define WSP_GGML_MAX_OPT 4 +#define WSP_GGML_MAX_NAME 48 +#define WSP_GGML_DEFAULT_N_THREADS 4 + +#define WSP_GGML_UNUSED(x) (void)(x) + +#define WSP_GGML_ASSERT(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +// used to copy the number of elements and stride in bytes of tensors into local variables. +// main purpose is to reduce code duplication and improve readability. 
+// +// example: +// +// WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); +// WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); +// +#define WSP_GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \ + const type prefix##0 = (pointer)->array[0]; \ + WSP_GGML_UNUSED(prefix##0); +#define WSP_GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \ + WSP_GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \ + const type prefix##1 = (pointer)->array[1]; \ + WSP_GGML_UNUSED(prefix##1); +#define WSP_GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \ + WSP_GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \ + const type prefix##2 = (pointer)->array[2]; \ + WSP_GGML_UNUSED(prefix##2); +#define WSP_GGML_TENSOR_LOCALS(type, prefix, pointer, array) \ + WSP_GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \ + const type prefix##3 = (pointer)->array[3]; \ + WSP_GGML_UNUSED(prefix##3); + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef __ARM_NEON + // we use the built-in 16-bit float type + typedef __fp16 wsp_ggml_fp16_t; +#else + typedef uint16_t wsp_ggml_fp16_t; +#endif + + // convert FP16 <-> FP32 + WSP_GGML_API float wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t x); + WSP_GGML_API wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float x); + + WSP_GGML_API void wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t * x, float * y, size_t n); + WSP_GGML_API void wsp_ggml_fp32_to_fp16_row(const float * x, wsp_ggml_fp16_t * y, size_t n); + + struct wsp_ggml_object; + struct wsp_ggml_context; + + enum wsp_ggml_type { + WSP_GGML_TYPE_F32 = 0, + WSP_GGML_TYPE_F16 = 1, + WSP_GGML_TYPE_Q4_0 = 2, + WSP_GGML_TYPE_Q4_1 = 3, + // WSP_GGML_TYPE_Q4_2 = 4, support has been removed + // WSP_GGML_TYPE_Q4_3 (5) support has been removed + WSP_GGML_TYPE_Q5_0 = 6, + WSP_GGML_TYPE_Q5_1 = 7, + WSP_GGML_TYPE_Q8_0 = 8, + WSP_GGML_TYPE_Q8_1 = 9, + // k-quantizations + WSP_GGML_TYPE_Q2_K = 10, + WSP_GGML_TYPE_Q3_K = 11, + WSP_GGML_TYPE_Q4_K = 12, + WSP_GGML_TYPE_Q5_K = 13, + WSP_GGML_TYPE_Q6_K = 14, + WSP_GGML_TYPE_Q8_K = 15, + WSP_GGML_TYPE_I8, + WSP_GGML_TYPE_I16, + WSP_GGML_TYPE_I32, + WSP_GGML_TYPE_COUNT, + }; + + enum wsp_ggml_backend { + WSP_GGML_BACKEND_CPU = 0, + WSP_GGML_BACKEND_GPU = 10, + WSP_GGML_BACKEND_GPU_SPLIT = 20, + }; + + // model file types + enum wsp_ggml_ftype { + WSP_GGML_FTYPE_UNKNOWN = -1, + WSP_GGML_FTYPE_ALL_F32 = 0, + WSP_GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + WSP_GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + WSP_GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + WSP_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + WSP_GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + WSP_GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + WSP_GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + WSP_GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors + WSP_GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors + WSP_GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors + WSP_GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors + WSP_GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors + }; + + // available tensor operations: + enum wsp_ggml_op { + WSP_GGML_OP_NONE = 0, + + WSP_GGML_OP_DUP, + WSP_GGML_OP_ADD, + WSP_GGML_OP_ADD1, + WSP_GGML_OP_ACC, + WSP_GGML_OP_SUB, + WSP_GGML_OP_MUL, + WSP_GGML_OP_DIV, + WSP_GGML_OP_SQR, + WSP_GGML_OP_SQRT, + WSP_GGML_OP_LOG, + WSP_GGML_OP_SUM, + WSP_GGML_OP_SUM_ROWS, + WSP_GGML_OP_MEAN, + WSP_GGML_OP_ARGMAX, + WSP_GGML_OP_REPEAT, + WSP_GGML_OP_REPEAT_BACK, + WSP_GGML_OP_ABS, + WSP_GGML_OP_SGN, + WSP_GGML_OP_NEG, + WSP_GGML_OP_STEP, + WSP_GGML_OP_TANH, + WSP_GGML_OP_ELU, 
+ WSP_GGML_OP_RELU, + WSP_GGML_OP_GELU, + WSP_GGML_OP_GELU_QUICK, + WSP_GGML_OP_SILU, + WSP_GGML_OP_SILU_BACK, + WSP_GGML_OP_NORM, // normalize + WSP_GGML_OP_RMS_NORM, + WSP_GGML_OP_RMS_NORM_BACK, + + WSP_GGML_OP_MUL_MAT, + WSP_GGML_OP_OUT_PROD, + + WSP_GGML_OP_SCALE, + WSP_GGML_OP_SET, + WSP_GGML_OP_CPY, + WSP_GGML_OP_CONT, + WSP_GGML_OP_RESHAPE, + WSP_GGML_OP_VIEW, + WSP_GGML_OP_PERMUTE, + WSP_GGML_OP_TRANSPOSE, + WSP_GGML_OP_GET_ROWS, + WSP_GGML_OP_GET_ROWS_BACK, + WSP_GGML_OP_DIAG, + WSP_GGML_OP_DIAG_MASK_INF, + WSP_GGML_OP_DIAG_MASK_ZERO, + WSP_GGML_OP_SOFT_MAX, + WSP_GGML_OP_SOFT_MAX_BACK, + WSP_GGML_OP_ROPE, + WSP_GGML_OP_ROPE_BACK, + WSP_GGML_OP_ALIBI, + WSP_GGML_OP_CLAMP, + WSP_GGML_OP_CONV_1D, + WSP_GGML_OP_CONV_2D, + + WSP_GGML_OP_FLASH_ATTN, + WSP_GGML_OP_FLASH_FF, + WSP_GGML_OP_FLASH_ATTN_BACK, + WSP_GGML_OP_WIN_PART, + WSP_GGML_OP_WIN_UNPART, + + WSP_GGML_OP_MAP_UNARY, + WSP_GGML_OP_MAP_BINARY, + + WSP_GGML_OP_MAP_CUSTOM1, + WSP_GGML_OP_MAP_CUSTOM2, + WSP_GGML_OP_MAP_CUSTOM3, + + WSP_GGML_OP_CROSS_ENTROPY_LOSS, + WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK, + + WSP_GGML_OP_COUNT, + }; + + + // ggml object + struct wsp_ggml_object { + size_t offs; + size_t size; + + struct wsp_ggml_object * next; + + char padding[8]; + }; + + static const size_t WSP_GGML_OBJECT_SIZE = sizeof(struct wsp_ggml_object); + + // n-dimensional tensor + struct wsp_ggml_tensor { + enum wsp_ggml_type type; + enum wsp_ggml_backend backend; + + int n_dims; + int64_t ne[WSP_GGML_MAX_DIMS]; // number of elements + size_t nb[WSP_GGML_MAX_DIMS]; // stride in bytes: + // nb[0] = sizeof(type) + // nb[1] = nb[0] * ne[0] + padding + // nb[i] = nb[i-1] * ne[i-1] + + // compute data + enum wsp_ggml_op op; + + bool is_param; + + struct wsp_ggml_tensor * grad; + struct wsp_ggml_tensor * src0; + struct wsp_ggml_tensor * src1; + struct wsp_ggml_tensor * opt[WSP_GGML_MAX_OPT]; + + // thread scheduling + int n_tasks; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; + + void * data; + + char name[WSP_GGML_MAX_NAME]; + + void * extra; // extra things e.g. for ggml-cuda.cu + + char padding[4]; + }; + + static const size_t WSP_GGML_TENSOR_SIZE = sizeof(struct wsp_ggml_tensor); + + // computation graph + struct wsp_ggml_cgraph { + int n_nodes; + int n_leafs; + int n_threads; + + size_t work_size; + struct wsp_ggml_tensor * work; + + struct wsp_ggml_tensor * nodes[WSP_GGML_MAX_NODES]; + struct wsp_ggml_tensor * grads[WSP_GGML_MAX_NODES]; + struct wsp_ggml_tensor * leafs[WSP_GGML_MAX_NODES]; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; + }; + + // scratch buffer + struct wsp_ggml_scratch { + size_t offs; + size_t size; + void * data; + }; + + struct wsp_ggml_init_params { + // memory pool + size_t mem_size; // bytes + void * mem_buffer; // if NULL, memory will be allocated internally + bool no_alloc; // don't allocate memory for the tensor data + }; + + + // compute types + + // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled. + // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995. 
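    // For reference, the compute kernels in ggml.c typically shard the work of a node over
    // the nth worker threads using the thread index ith, roughly as in this illustrative
    // sketch (nr being the number of rows to process):
    //
    //   const int dr  = (nr + nth - 1)/nth;   // rows handled per thread
    //   const int ir0 = dr*ith;               // first row for this thread
    //   const int ir1 = MIN(ir0 + dr, nr);    // one past the last row for this thread
    //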
+ enum wsp_ggml_task_type { + WSP_GGML_TASK_INIT = 0, + WSP_GGML_TASK_COMPUTE, + WSP_GGML_TASK_FINALIZE, + }; + + struct wsp_ggml_compute_params { + enum wsp_ggml_task_type type; + + // ith = thread index, nth = number of threads + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; + }; + + // misc + + WSP_GGML_API void wsp_ggml_time_init(void); // call this once at the beginning of the program + WSP_GGML_API int64_t wsp_ggml_time_ms(void); + WSP_GGML_API int64_t wsp_ggml_time_us(void); + WSP_GGML_API int64_t wsp_ggml_cycles(void); + WSP_GGML_API int64_t wsp_ggml_cycles_per_ms(void); + + WSP_GGML_API void wsp_ggml_numa_init(void); // call once for better performance on NUMA systems + WSP_GGML_API bool wsp_ggml_is_numa(void); // true if init detected that system has >1 NUMA node + + WSP_GGML_API void wsp_ggml_print_object (const struct wsp_ggml_object * obj); + WSP_GGML_API void wsp_ggml_print_objects(const struct wsp_ggml_context * ctx); + + WSP_GGML_API int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API int64_t wsp_ggml_nrows (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API size_t wsp_ggml_nbytes (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API size_t wsp_ggml_nbytes_split(const struct wsp_ggml_tensor * tensor, int nrows_split); + + WSP_GGML_API int wsp_ggml_blck_size (enum wsp_ggml_type type); + WSP_GGML_API size_t wsp_ggml_type_size (enum wsp_ggml_type type); // size in bytes for all elements in a block + WSP_GGML_API float wsp_ggml_type_sizef(enum wsp_ggml_type type); // wsp_ggml_type_size()/wsp_ggml_blck_size() as float + + WSP_GGML_API const char * wsp_ggml_type_name(enum wsp_ggml_type type); + WSP_GGML_API const char * wsp_ggml_op_name (enum wsp_ggml_op op); + + WSP_GGML_API size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor); + + WSP_GGML_API bool wsp_ggml_is_quantized(enum wsp_ggml_type type); + + // TODO: temporary until model loading of ggml examples is refactored + WSP_GGML_API enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype); + + WSP_GGML_API bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor); + WSP_GGML_API bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor); + WSP_GGML_API bool wsp_ggml_is_permuted (const struct wsp_ggml_tensor * tensor); + + // use this to compute the memory overhead of a tensor + WSP_GGML_API size_t wsp_ggml_tensor_overhead(void); + + // main + + WSP_GGML_API struct wsp_ggml_context * wsp_ggml_init(struct wsp_ggml_init_params params); + WSP_GGML_API void wsp_ggml_free(struct wsp_ggml_context * ctx); + + WSP_GGML_API size_t wsp_ggml_used_mem(const struct wsp_ggml_context * ctx); + + WSP_GGML_API size_t wsp_ggml_set_scratch (struct wsp_ggml_context * ctx, struct wsp_ggml_scratch scratch); + WSP_GGML_API void wsp_ggml_set_no_alloc(struct wsp_ggml_context * ctx, bool no_alloc); + + WSP_GGML_API void * wsp_ggml_get_mem_buffer (const struct wsp_ggml_context * ctx); + WSP_GGML_API size_t wsp_ggml_get_mem_size (const struct wsp_ggml_context * ctx); + WSP_GGML_API size_t wsp_ggml_get_max_tensor_size(const struct wsp_ggml_context * ctx); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_tensor( + struct wsp_ggml_context * ctx, + enum wsp_ggml_type type, + int n_dims, + const int64_t *ne); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_tensor_1d( + struct wsp_ggml_context * ctx, + enum wsp_ggml_type type, + int64_t ne0); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_tensor_2d( + struct 
wsp_ggml_context * ctx, + enum wsp_ggml_type type, + int64_t ne0, + int64_t ne1); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_tensor_3d( + struct wsp_ggml_context * ctx, + enum wsp_ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_tensor_4d( + struct wsp_ggml_context * ctx, + enum wsp_ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_i32(struct wsp_ggml_context * ctx, int32_t value); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_f32(struct wsp_ggml_context * ctx, float value); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_dup_tensor (struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_tensor(struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_i32 (struct wsp_ggml_tensor * tensor, int32_t value); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_f32 (struct wsp_ggml_tensor * tensor, float value); + + WSP_GGML_API int32_t wsp_ggml_get_i32_1d(const struct wsp_ggml_tensor * tensor, int i); + WSP_GGML_API void wsp_ggml_set_i32_1d(const struct wsp_ggml_tensor * tensor, int i, int32_t value); + + WSP_GGML_API float wsp_ggml_get_f32_1d(const struct wsp_ggml_tensor * tensor, int i); + WSP_GGML_API void wsp_ggml_set_f32_1d(const struct wsp_ggml_tensor * tensor, int i, float value); + + WSP_GGML_API void * wsp_ggml_get_data (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor); + + WSP_GGML_API const char * wsp_ggml_get_name(const struct wsp_ggml_tensor * tensor); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_name(struct wsp_ggml_tensor * tensor, const char * name); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_format_name(struct wsp_ggml_tensor * tensor, const char * fmt, ...); + + // + // operations on tensors with backpropagation + // + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_dup( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add1( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add1_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_acc( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_acc_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sub( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct 
wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sub_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_div( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_div_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sqr( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sqr_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sqrt( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sqrt_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_log( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_log_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + // return scalar + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sum( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d] + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sum_rows( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + // mean along rows + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mean( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + // argmax along rows + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_argmax( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + // if a is the same shape as b, and a is not parameter, return a + // otherwise, return a new tensor: repeat(a) to fit in b + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_abs( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_abs_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sgn( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sgn_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_neg( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_neg_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_step( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_step_inplace( + struct 
wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_tanh( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_tanh_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_elu( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_elu_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_relu( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_relu_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + // TODO: double-check this computation is correct + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu_quick( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu_quick_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_silu( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_silu_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + // a - x + // b - dy + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_silu_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + // normalize along rows + // TODO: eps is hardcoded to 1e-5 for now + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_norm( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_norm_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rms_norm( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rms_norm_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + // a - x + // b - dy + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rms_norm_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + // A: n columns, m rows + // B: n columns, p rows (i.e. 
we transpose it internally) + // result is m columns, p rows + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + // A: m columns, n rows, + // B: p columns, n rows, + // result is m columns, p rows + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_out_prod( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + // + // operations on tensors without backpropagation + // + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_scale( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + // in-place, returns view(a) + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_scale_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + // b -> view(a,offset,nb1,nb2,3), return modified a + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + // b -> view(a,offset,nb1,nb2,3), return view(a) + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t offset); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_1d_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t offset); + + // b -> view(a,offset,nb1,nb2,3), return modified a + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_2d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t offset); + + // b -> view(a,offset,nb1,nb2,3), return view(a) + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + size_t nb1, + size_t offset); + + + // a -> b, return view(b) + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + // make contiguous + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + // return view(a), b specifies the new shape + // TODO: when we start computing gradient, make a copy instead of view + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_reshape( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + // return view(a) + // TODO: when we start computing gradient, make a copy instead of view + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_reshape_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_reshape_2d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1); + + // return view(a) + // TODO: when we start computing gradient, make a copy instead of view + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_reshape_3d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_reshape_4d( + struct 
wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + // offset in bytes + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + size_t offset); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_2d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + size_t nb1, // row stride in bytes + size_t offset); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_3d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + size_t nb1, // row stride in bytes + size_t nb2, // slice stride in bytes + size_t offset); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_4d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + size_t nb1, // row stride in bytes + size_t nb2, // slice stride in bytes + size_t nb3, + size_t offset); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_permute( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3); + + // alias for wsp_ggml_permute(ctx, a, 1, 0, 2, 3) + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_transpose( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + struct wsp_ggml_tensor * c); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_diag( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + // set elements above the diagonal to -INF + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_diag_mask_inf( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past); + + // in-place, returns view(a) + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_diag_mask_inf_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past); + + // set elements above the diagonal to 0 + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_diag_mask_zero( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past); + + // in-place, returns view(a) + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_diag_mask_zero_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + // in-place, returns view(a) + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + // in-place, returns view(a) + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + // rotary position embedding + // if mode & 1 == 1, skip n_past elements + // if mode & 2 == 1, GPT-NeoX style + // if mode & 4 == 1, ChatGLM style + // TODO: avoid creating a new tensor every time + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope( + struct wsp_ggml_context * ctx, 
+ struct wsp_ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx); + + // in-place, returns view(a) + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_inplace( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx); + + // rotary position embedding backward, i.e compute dx from dy + // a - dy + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past, + int n_dims, + int mode); + + // alibi position embedding + // in-place, returns view(a) + struct wsp_ggml_tensor * wsp_ggml_alibi( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int n_past, + int n_head, + float bias_max); + + // clamp + // in-place, returns view(a) + struct wsp_ggml_tensor * wsp_ggml_clamp( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + float min, + float max); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s0, // stride + int p0, // padding + int d0); // dilation + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1); + + // conv_1d with padding = half + // alias for wsp_ggml_conv_1d(a, b, s, a->ne[0]/2, d) + WSP_GGML_API struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s, + int d); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * q, + struct wsp_ggml_tensor * k, + struct wsp_ggml_tensor * v, + bool masked); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * q, + struct wsp_ggml_tensor * k, + struct wsp_ggml_tensor * v, + struct wsp_ggml_tensor * d, + bool masked); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_ff( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b0, + struct wsp_ggml_tensor * b1, + struct wsp_ggml_tensor * c0, + struct wsp_ggml_tensor * c1); + + // partition into non-overlapping windows with padding if needed + // example: + // a: 768 64 64 1 + // w: 14 + // res: 768 14 14 25 + // used in sam + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_win_part( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int w); + + // reverse of wsp_ggml_win_part + // used in sam + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_win_unpart( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int w0, + int h0, + int w); + + // custom operators + + typedef void (*wsp_ggml_unary_op_f32_t) (const int, float *, const float *); + typedef void (*wsp_ggml_binary_op_f32_t)(const int, float *, const float *, const float *); + + typedef void (*wsp_ggml_custom1_op_f32_t)(struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *); + typedef void (*wsp_ggml_custom2_op_f32_t)(struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *); + typedef void (*wsp_ggml_custom3_op_f32_t)(struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_unary_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + 
wsp_ggml_unary_op_f32_t fun); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_unary_inplace_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + wsp_ggml_unary_op_f32_t fun); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_binary_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + wsp_ggml_binary_op_f32_t fun); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_binary_inplace_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + wsp_ggml_binary_op_f32_t fun); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + wsp_ggml_custom1_op_f32_t fun); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1_inplace_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + wsp_ggml_custom1_op_f32_t fun); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom2_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + wsp_ggml_custom2_op_f32_t fun); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom2_inplace_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + wsp_ggml_custom2_op_f32_t fun); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom3_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + struct wsp_ggml_tensor * c, + wsp_ggml_custom3_op_f32_t fun); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom3_inplace_f32( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + struct wsp_ggml_tensor * c, + wsp_ggml_custom3_op_f32_t fun); + + // loss function + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss_back( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + struct wsp_ggml_tensor * c); + + // + // automatic differentiation + // + + WSP_GGML_API void wsp_ggml_set_param( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * tensor); + + WSP_GGML_API void wsp_ggml_build_forward_expand(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor); + + WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_forward (struct wsp_ggml_tensor * tensor); + WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_backward(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, bool keep); + + WSP_GGML_API void wsp_ggml_graph_compute(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph); + WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(struct wsp_ggml_cgraph * cgraph, const char * name); + + WSP_GGML_API void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname); + WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval); + + // print info and performance information for the graph + WSP_GGML_API void wsp_ggml_graph_print(const struct wsp_ggml_cgraph * cgraph); + + // dump the graph into a file using the dot format + WSP_GGML_API void wsp_ggml_graph_dump_dot(const struct wsp_ggml_cgraph * gb, const struct wsp_ggml_cgraph * gf, 
const char * filename); + + // + // optimization + // + + // optimization methods + enum wsp_ggml_opt_type { + WSP_GGML_OPT_ADAM, + WSP_GGML_OPT_LBFGS, + }; + + // linesearch methods + enum wsp_ggml_linesearch { + WSP_GGML_LINESEARCH_DEFAULT = 1, + + WSP_GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, + WSP_GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, + WSP_GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, + }; + + // optimization return values + enum wsp_ggml_opt_result { + WSP_GGML_OPT_OK = 0, + WSP_GGML_OPT_DID_NOT_CONVERGE, + WSP_GGML_OPT_NO_CONTEXT, + WSP_GGML_OPT_INVALID_WOLFE, + WSP_GGML_OPT_FAIL, + + WSP_GGML_LINESEARCH_FAIL = -128, + WSP_GGML_LINESEARCH_MINIMUM_STEP, + WSP_GGML_LINESEARCH_MAXIMUM_STEP, + WSP_GGML_LINESEARCH_MAXIMUM_ITERATIONS, + WSP_GGML_LINESEARCH_INVALID_PARAMETERS, + }; + + // optimization parameters + // + // see ggml.c (wsp_ggml_opt_default_params) for default values + // + struct wsp_ggml_opt_params { + enum wsp_ggml_opt_type type; + + int n_threads; + + // delta-based convergence test + // + // if past == 0 - disabled + // if past > 0: + // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) + // + int past; + float delta; + + // maximum number of iterations without improvement + // + // if 0 - disabled + // if > 0: + // assume convergence if no cost improvement in this number of iterations + // + int max_no_improvement; + + bool print_forward_graph; + bool print_backward_graph; + + // ADAM parameters + struct { + int n_iter; + + float sched; // schedule multiplier (fixed, decay or warmup) + float decay; // weight decay for AdamW, use 0.0f to disable + float alpha; // learning rate + float beta1; + float beta2; + float eps; // epsilon for numerical stability + float eps_f; // epsilon for convergence test + float eps_g; // epsilon for convergence test + } adam; + + // LBFGS parameters + struct { + int m; // number of corrections to approximate the inv. 
Hessian + int n_iter; + int max_linesearch; + + float eps; // convergence tolerance + float ftol; // line search tolerance + float wolfe; + float min_step; + float max_step; + + enum wsp_ggml_linesearch linesearch; + } lbfgs; + }; + + struct wsp_ggml_opt_context { + struct wsp_ggml_context * ctx; + struct wsp_ggml_opt_params params; + + int iter; + int64_t nx; // number of parameter elements + + bool just_initialized; + + struct { + struct wsp_ggml_tensor * x; // view of the parameters + struct wsp_ggml_tensor * g1; // gradient + struct wsp_ggml_tensor * g2; // gradient squared + struct wsp_ggml_tensor * m; // first moment + struct wsp_ggml_tensor * v; // second moment + struct wsp_ggml_tensor * mh; // first moment hat + struct wsp_ggml_tensor * vh; // second moment hat + struct wsp_ggml_tensor * pf; // past function values + float fx_best; + float fx_prev; + int n_no_improvement; + } adam; + + struct { + struct wsp_ggml_tensor * x; // current parameters + struct wsp_ggml_tensor * xp; // previous parameters + struct wsp_ggml_tensor * g; // current gradient + struct wsp_ggml_tensor * gp; // previous gradient + struct wsp_ggml_tensor * d; // search direction + struct wsp_ggml_tensor * pf; // past function values + struct wsp_ggml_tensor * lmal; // the L-BFGS memory alpha + struct wsp_ggml_tensor * lmys; // the L-BFGS memory ys + struct wsp_ggml_tensor * lms; // the L-BFGS memory s + struct wsp_ggml_tensor * lmy; // the L-BFGS memory y + float fx_best; + float step; + int j; + int k; + int end; + int n_no_improvement; + } lbfgs; + }; + + WSP_GGML_API struct wsp_ggml_opt_params wsp_ggml_opt_default_params(enum wsp_ggml_opt_type type); + + // optimize the function defined by the tensor f + WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt( + struct wsp_ggml_context * ctx, + struct wsp_ggml_opt_params params, + struct wsp_ggml_tensor * f); + + // initialize optimizer context + WSP_GGML_API void wsp_ggml_opt_init( + struct wsp_ggml_context * ctx, + struct wsp_ggml_opt_context * opt, + struct wsp_ggml_opt_params params, + int64_t nx); + + // continue optimizing the function defined by the tensor f + WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt_resume( + struct wsp_ggml_context * ctx, + struct wsp_ggml_opt_context * opt, + struct wsp_ggml_tensor * f); + + // continue optimizing the function defined by the tensor f + WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt_resume_g( + struct wsp_ggml_context * ctx, + struct wsp_ggml_opt_context * opt, + struct wsp_ggml_tensor * f, + struct wsp_ggml_cgraph * gf, + struct wsp_ggml_cgraph * gb); + + // + // quantization + // + + WSP_GGML_API size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); + WSP_GGML_API size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); + WSP_GGML_API size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); + WSP_GGML_API size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); + WSP_GGML_API size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); + + WSP_GGML_API size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); + + // + // system info + // + + WSP_GGML_API int wsp_ggml_cpu_has_avx (void); + WSP_GGML_API int wsp_ggml_cpu_has_avx2 (void); + WSP_GGML_API int wsp_ggml_cpu_has_avx512 (void); + WSP_GGML_API int wsp_ggml_cpu_has_avx512_vbmi(void); + WSP_GGML_API int 
wsp_ggml_cpu_has_avx512_vnni(void); + WSP_GGML_API int wsp_ggml_cpu_has_fma (void); + WSP_GGML_API int wsp_ggml_cpu_has_neon (void); + WSP_GGML_API int wsp_ggml_cpu_has_arm_fma (void); + WSP_GGML_API int wsp_ggml_cpu_has_f16c (void); + WSP_GGML_API int wsp_ggml_cpu_has_fp16_va (void); + WSP_GGML_API int wsp_ggml_cpu_has_wasm_simd (void); + WSP_GGML_API int wsp_ggml_cpu_has_blas (void); + WSP_GGML_API int wsp_ggml_cpu_has_cublas (void); + WSP_GGML_API int wsp_ggml_cpu_has_clblast (void); + WSP_GGML_API int wsp_ggml_cpu_has_gpublas (void); + WSP_GGML_API int wsp_ggml_cpu_has_sse3 (void); + WSP_GGML_API int wsp_ggml_cpu_has_ssse3 (void); + WSP_GGML_API int wsp_ggml_cpu_has_vsx (void); + + // + // Internal types and functions exposed for tests and benchmarks + // + +#ifdef __cplusplus + // restrict not standard in C++ +#define WSP_GGML_RESTRICT +#else +#define WSP_GGML_RESTRICT restrict +#endif + typedef void (*dequantize_row_q_t)(const void * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int k); + typedef void (*quantize_row_q_t) (const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int k); + typedef void (*vec_dot_q_t) (const int n, float * WSP_GGML_RESTRICT s, const void * WSP_GGML_RESTRICT x, const void * WSP_GGML_RESTRICT y); + + typedef struct { + dequantize_row_q_t dequantize_row_q; + quantize_row_q_t quantize_row_q; + quantize_row_q_t quantize_row_q_reference; + quantize_row_q_t quantize_row_q_dot; + vec_dot_q_t vec_dot_q; + enum wsp_ggml_type vec_dot_type; + } quantize_fns_t; + + quantize_fns_t wsp_ggml_internal_get_quantize_fn(size_t i); + +#ifdef __cplusplus +} +#endif diff --git a/cpp/whisper.cpp b/cpp/whisper.cpp new file mode 100644 index 0000000..be83206 --- /dev/null +++ b/cpp/whisper.cpp @@ -0,0 +1,5512 @@ +#include "whisper.h" +#ifdef WHISPER_USE_COREML +#include "coreml/whisper-encoder.h" +#endif + +#if WHISPER_USE_OPENVINO +#include "openvino/whisper-openvino-encoder.h" +#endif + +#include "ggml.h" + +#include +#include +#define _USE_MATH_DEFINES +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +#if defined(WSP_GGML_BIG_ENDIAN) +#include + +template +static T byteswap(T value) { + return std::byteswap(value); +} + +template<> +float byteswap(float value) { + return std::bit_cast(byteswap(std::bit_cast(value))); +} + +template +static void byteswap_tensor_data(wsp_ggml_tensor * tensor) { + T * datum = reinterpret_cast(tensor->data); + for (int i = 0; i < wsp_ggml_nelements(tensor); i++) { + datum[i] = byteswap(datum[i]); + } +} + +static void byteswap_tensor(wsp_ggml_tensor * tensor) { + switch (tensor->type) { + case WSP_GGML_TYPE_I16: { + byteswap_tensor_data(tensor); + break; + } + case WSP_GGML_TYPE_F16: { + byteswap_tensor_data(tensor); + break; + } + case WSP_GGML_TYPE_I32: { + byteswap_tensor_data(tensor); + break; + } + case WSP_GGML_TYPE_F32: { + byteswap_tensor_data(tensor); + break; + } + default: { // GML_TYPE_I8 + break; + } + } +} + +#define BYTESWAP_VALUE(d) d = byteswap(d) +#define BYTESWAP_FILTERS(f) \ + do { \ + for (auto & datum : f.data) { \ + datum = byteswap(datum); \ + } \ + } while (0) +#define BYTESWAP_TENSOR(t) \ + do { \ + byteswap_tensor(t); \ + } while (0) +#else +#define BYTESWAP_VALUE(d) do {} while (0) +#define BYTESWAP_FILTERS(f) do {} while (0) +#define BYTESWAP_TENSOR(t) do {} while (0) +#endif + +#define WHISPER_ASSERT(x) \ + do { \ + if (!(x)) { \ + log("WHISPER_ASSERT: 
%s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +// define this to enable verbose trace logging - useful for debugging purposes +//#define WHISPER_DEBUG + +#if defined(WHISPER_DEBUG) +#define WHISPER_PRINT_DEBUG(...) \ + do { \ + fprintf(stderr, __VA_ARGS__); \ + } while (0) +#else +#define WHISPER_PRINT_DEBUG(...) +#endif + +//#define WHISPER_USE_FLASH_ATTN +//#define WHISPER_USE_FLASH_FF +#define WHISPER_MAX_DECODERS 16 + +#define WHISPER_USE_SCRATCH +#define WHISPER_MAX_SCRATCH_BUFFERS 16 + +// available whisper models +enum e_model { + MODEL_UNKNOWN, + MODEL_TINY, + MODEL_BASE, + MODEL_SMALL, + MODEL_MEDIUM, + MODEL_LARGE, +}; + +static const std::map> g_lang = { + { "en", { 0, "english", } }, + { "zh", { 1, "chinese", } }, + { "de", { 2, "german", } }, + { "es", { 3, "spanish", } }, + { "ru", { 4, "russian", } }, + { "ko", { 5, "korean", } }, + { "fr", { 6, "french", } }, + { "ja", { 7, "japanese", } }, + { "pt", { 8, "portuguese", } }, + { "tr", { 9, "turkish", } }, + { "pl", { 10, "polish", } }, + { "ca", { 11, "catalan", } }, + { "nl", { 12, "dutch", } }, + { "ar", { 13, "arabic", } }, + { "sv", { 14, "swedish", } }, + { "it", { 15, "italian", } }, + { "id", { 16, "indonesian", } }, + { "hi", { 17, "hindi", } }, + { "fi", { 18, "finnish", } }, + { "vi", { 19, "vietnamese", } }, + { "he", { 20, "hebrew", } }, + { "uk", { 21, "ukrainian", } }, + { "el", { 22, "greek", } }, + { "ms", { 23, "malay", } }, + { "cs", { 24, "czech", } }, + { "ro", { 25, "romanian", } }, + { "da", { 26, "danish", } }, + { "hu", { 27, "hungarian", } }, + { "ta", { 28, "tamil", } }, + { "no", { 29, "norwegian", } }, + { "th", { 30, "thai", } }, + { "ur", { 31, "urdu", } }, + { "hr", { 32, "croatian", } }, + { "bg", { 33, "bulgarian", } }, + { "lt", { 34, "lithuanian", } }, + { "la", { 35, "latin", } }, + { "mi", { 36, "maori", } }, + { "ml", { 37, "malayalam", } }, + { "cy", { 38, "welsh", } }, + { "sk", { 39, "slovak", } }, + { "te", { 40, "telugu", } }, + { "fa", { 41, "persian", } }, + { "lv", { 42, "latvian", } }, + { "bn", { 43, "bengali", } }, + { "sr", { 44, "serbian", } }, + { "az", { 45, "azerbaijani", } }, + { "sl", { 46, "slovenian", } }, + { "kn", { 47, "kannada", } }, + { "et", { 48, "estonian", } }, + { "mk", { 49, "macedonian", } }, + { "br", { 50, "breton", } }, + { "eu", { 51, "basque", } }, + { "is", { 52, "icelandic", } }, + { "hy", { 53, "armenian", } }, + { "ne", { 54, "nepali", } }, + { "mn", { 55, "mongolian", } }, + { "bs", { 56, "bosnian", } }, + { "kk", { 57, "kazakh", } }, + { "sq", { 58, "albanian", } }, + { "sw", { 59, "swahili", } }, + { "gl", { 60, "galician", } }, + { "mr", { 61, "marathi", } }, + { "pa", { 62, "punjabi", } }, + { "si", { 63, "sinhala", } }, + { "km", { 64, "khmer", } }, + { "sn", { 65, "shona", } }, + { "yo", { 66, "yoruba", } }, + { "so", { 67, "somali", } }, + { "af", { 68, "afrikaans", } }, + { "oc", { 69, "occitan", } }, + { "ka", { 70, "georgian", } }, + { "be", { 71, "belarusian", } }, + { "tg", { 72, "tajik", } }, + { "sd", { 73, "sindhi", } }, + { "gu", { 74, "gujarati", } }, + { "am", { 75, "amharic", } }, + { "yi", { 76, "yiddish", } }, + { "lo", { 77, "lao", } }, + { "uz", { 78, "uzbek", } }, + { "fo", { 79, "faroese", } }, + { "ht", { 80, "haitian creole", } }, + { "ps", { 81, "pashto", } }, + { "tk", { 82, "turkmen", } }, + { "nn", { 83, "nynorsk", } }, + { "mt", { 84, "maltese", } }, + { "sa", { 85, "sanskrit", } }, + { "lb", { 86, "luxembourgish", } }, + { "my", { 87, "myanmar", } }, + { "bo", { 88, "tibetan", } }, + 
{ "tl", { 89, "tagalog", } }, + { "mg", { 90, "malagasy", } }, + { "as", { 91, "assamese", } }, + { "tt", { 92, "tatar", } }, + { "haw", { 93, "hawaiian", } }, + { "ln", { 94, "lingala", } }, + { "ha", { 95, "hausa", } }, + { "ba", { 96, "bashkir", } }, + { "jw", { 97, "javanese", } }, + { "su", { 98, "sundanese", } }, +}; + +static const size_t MB = 1ull*1024*1024; + +static const std::map MEM_REQ_SCRATCH0 = { + { MODEL_TINY, 62ull*MB }, + { MODEL_BASE, 80ull*MB }, + { MODEL_SMALL, 120ull*MB }, + { MODEL_MEDIUM, 158ull*MB }, + { MODEL_LARGE, 198ull*MB }, +}; + +static const std::map MEM_REQ_SCRATCH1 = { + { MODEL_TINY, 18ull*MB }, + { MODEL_BASE, 24ull*MB }, + { MODEL_SMALL, 36ull*MB }, + { MODEL_MEDIUM, 48ull*MB }, + { MODEL_LARGE, 60ull*MB }, +}; + +static const std::map MEM_REQ_SCRATCH2 = { + { MODEL_TINY, 4ull*MB }, + { MODEL_BASE, 4ull*MB }, + { MODEL_SMALL, 6ull*MB }, + { MODEL_MEDIUM, 7ull*MB }, + { MODEL_LARGE, 9ull*MB }, +}; + +static const std::map MEM_REQ_SCRATCH3 = { + { MODEL_TINY, 4ull*MB }, + { MODEL_BASE, 4ull*MB }, + { MODEL_SMALL, 6ull*MB }, + { MODEL_MEDIUM, 7ull*MB }, + { MODEL_LARGE, 9ull*MB }, +}; + +static const std::map> MEM_REQ_MODEL = { + { WSP_GGML_TYPE_F32, + { + { MODEL_TINY, 74ull*MB }, + { MODEL_BASE, 142ull*MB }, + { MODEL_SMALL, 466ull*MB }, + { MODEL_MEDIUM, 1464ull*MB }, + { MODEL_LARGE, 2952ull*MB }, + }, + }, + { WSP_GGML_TYPE_F16, + { + { MODEL_TINY, 74ull*MB }, + { MODEL_BASE, 142ull*MB }, + { MODEL_SMALL, 466ull*MB }, + { MODEL_MEDIUM, 1464ull*MB }, + { MODEL_LARGE, 2952ull*MB }, + }, + }, + { WSP_GGML_TYPE_Q4_0, + { + { MODEL_TINY, 26ull*MB }, + { MODEL_BASE, 50ull*MB }, + { MODEL_SMALL, 154ull*MB }, + { MODEL_MEDIUM, 470ull*MB }, + { MODEL_LARGE, 940ull*MB }, + }, + }, + { WSP_GGML_TYPE_Q4_1, + { + { MODEL_TINY, 32ull*MB }, + { MODEL_BASE, 58ull*MB }, + { MODEL_SMALL, 182ull*MB }, + { MODEL_MEDIUM, 562ull*MB }, + { MODEL_LARGE, 1124ull*MB }, + }, + }, + { WSP_GGML_TYPE_Q5_0, + { + { MODEL_TINY, 30ull*MB }, + { MODEL_BASE, 54ull*MB }, + { MODEL_SMALL, 170ull*MB }, + { MODEL_MEDIUM, 516ull*MB }, + { MODEL_LARGE, 1034ull*MB }, + }, + }, + { WSP_GGML_TYPE_Q5_1, + { + { MODEL_TINY, 32ull*MB }, + { MODEL_BASE, 58ull*MB }, + { MODEL_SMALL, 182ull*MB }, + { MODEL_MEDIUM, 562ull*MB }, + { MODEL_LARGE, 1124ull*MB }, + }, + }, + { WSP_GGML_TYPE_Q8_0, + { + { MODEL_TINY, 45ull*MB }, + { MODEL_BASE, 84ull*MB }, + { MODEL_SMALL, 268ull*MB }, + { MODEL_MEDIUM, 834ull*MB }, + { MODEL_LARGE, 1674ull*MB }, + }, + }, +}; + +static const std::map MEM_REQ_KV_SELF = { + { MODEL_TINY, 3ull*MB }, + { MODEL_BASE, 6ull*MB }, + { MODEL_SMALL, 16ull*MB }, + { MODEL_MEDIUM, 43ull*MB }, + { MODEL_LARGE, 71ull*MB }, +}; + +static const std::map MEM_REQ_KV_CROSS = { + { MODEL_TINY, 9ull*MB }, + { MODEL_BASE, 18ull*MB }, + { MODEL_SMALL, 53ull*MB }, + { MODEL_MEDIUM, 141ull*MB }, + { MODEL_LARGE, 235ull*MB }, +}; + +static const std::map MEM_REQ_ENCODE = { + { MODEL_TINY, 30ull*MB }, + { MODEL_BASE, 38ull*MB }, + { MODEL_SMALL, 56ull*MB }, + { MODEL_MEDIUM, 74ull*MB }, + { MODEL_LARGE, 94ull*MB }, +}; + +static const std::map MEM_REQ_DECODE = { + { MODEL_TINY, 3ull*MB }, + { MODEL_BASE, 5ull*MB }, + { MODEL_SMALL, 10ull*MB }, + { MODEL_MEDIUM, 18ull*MB }, + { MODEL_LARGE, 27ull*MB }, +}; + +struct whisper_mel { + int n_len; + int n_len_org; + int n_mel; + + std::vector data; +}; + +struct whisper_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector data; +}; + +struct whisper_vocab { + using id = int32_t; + using token = std::string; + + int n_vocab = 51864; + + std::map 
token_to_id; + std::map id_to_token; + + // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349 + id token_eot = 50256; + id token_sot = 50257; + // task tokens (used only for multilingual models) + id token_translate = 50357; + id token_transcribe = 50358; + // other special tokens + id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn + id token_prev = 50360; + id token_nosp = 50361; + id token_not = 50362; // no timestamps + id token_beg = 50363; // begin timestamps + + bool is_multilingual() const { + return n_vocab == 51865; + } +}; + +struct whisper_segment { + int64_t t0; + int64_t t1; + + std::string text; + + std::vector tokens; + + bool speaker_turn_next; +}; + +// medium +// hparams: { +// 'n_mels': 80, +// 'n_vocab': 51864, +// 'n_audio_ctx': 1500, +// 'n_audio_state': 1024, +// 'n_audio_head': 16, +// 'n_audio_layer': 24, +// 'n_text_ctx': 448, +// 'n_text_state': 1024, +// 'n_text_head': 16, +// 'n_text_layer': 24 +// } +// +// default hparams (Whisper tiny) +struct whisper_hparams { + int32_t n_vocab = 51864; + int32_t n_audio_ctx = 1500; + int32_t n_audio_state = 384; + int32_t n_audio_head = 6; + int32_t n_audio_layer = 4; + int32_t n_text_ctx = 448; + int32_t n_text_state = 384; + int32_t n_text_head = 6; + int32_t n_text_layer = 4; + int32_t n_mels = 80; + int32_t ftype = 1; +}; + +// audio encoding layer +struct whisper_layer_encoder { + // encoder.blocks.*.attn_ln + struct wsp_ggml_tensor * attn_ln_0_w; + struct wsp_ggml_tensor * attn_ln_0_b; + + // encoder.blocks.*.attn.out + struct wsp_ggml_tensor * attn_ln_1_w; + struct wsp_ggml_tensor * attn_ln_1_b; + + // encoder.blocks.*.attn.query + struct wsp_ggml_tensor * attn_q_w; + struct wsp_ggml_tensor * attn_q_b; + + // encoder.blocks.*.attn.key + struct wsp_ggml_tensor * attn_k_w; + + // encoder.blocks.*.attn.value + struct wsp_ggml_tensor * attn_v_w; + struct wsp_ggml_tensor * attn_v_b; + + // encoder.blocks.*.mlp_ln + struct wsp_ggml_tensor * mlp_ln_w; + struct wsp_ggml_tensor * mlp_ln_b; + + // encoder.blocks.*.mlp.0 + struct wsp_ggml_tensor * mlp_0_w; + struct wsp_ggml_tensor * mlp_0_b; + + // encoder.blocks.*.mlp.2 + struct wsp_ggml_tensor * mlp_1_w; + struct wsp_ggml_tensor * mlp_1_b; +}; + +// token decoding layer +struct whisper_layer_decoder { + // decoder.blocks.*.attn_ln + struct wsp_ggml_tensor * attn_ln_0_w; + struct wsp_ggml_tensor * attn_ln_0_b; + + // decoder.blocks.*.attn.out + struct wsp_ggml_tensor * attn_ln_1_w; + struct wsp_ggml_tensor * attn_ln_1_b; + + // decoder.blocks.*.attn.query + struct wsp_ggml_tensor * attn_q_w; + struct wsp_ggml_tensor * attn_q_b; + + // decoder.blocks.*.attn.key + struct wsp_ggml_tensor * attn_k_w; + + // decoder.blocks.*.attn.value + struct wsp_ggml_tensor * attn_v_w; + struct wsp_ggml_tensor * attn_v_b; + + // decoder.blocks.*.cross_attn_ln + struct wsp_ggml_tensor * cross_attn_ln_0_w; + struct wsp_ggml_tensor * cross_attn_ln_0_b; + + // decoder.blocks.*.cross_attn.out + struct wsp_ggml_tensor * cross_attn_ln_1_w; + struct wsp_ggml_tensor * cross_attn_ln_1_b; + + // decoder.blocks.*.cross_attn.query + struct wsp_ggml_tensor * cross_attn_q_w; + struct wsp_ggml_tensor * cross_attn_q_b; + + // decoder.blocks.*.cross_attn.key + struct wsp_ggml_tensor * cross_attn_k_w; + + // decoder.blocks.*.cross_attn.value + struct wsp_ggml_tensor * cross_attn_v_w; + struct wsp_ggml_tensor * cross_attn_v_b; + + // decoder.blocks.*.mlp_ln + struct wsp_ggml_tensor * mlp_ln_w; + struct 
wsp_ggml_tensor * mlp_ln_b; + + // decoder.blocks.*.mlp.0 + struct wsp_ggml_tensor * mlp_0_w; + struct wsp_ggml_tensor * mlp_0_b; + + // decoder.blocks.*.mlp.2 + struct wsp_ggml_tensor * mlp_1_w; + struct wsp_ggml_tensor * mlp_1_b; +}; + +struct whisper_kv_cache { + struct wsp_ggml_tensor * k; + struct wsp_ggml_tensor * v; + + struct wsp_ggml_context * ctx; + + std::vector buf; + + int n; // number of tokens currently in the cache +}; + +struct whisper_model { + e_model type = MODEL_UNKNOWN; + + whisper_hparams hparams; + whisper_filters filters; + + // encoder.positional_embedding + struct wsp_ggml_tensor * e_pe; + + // encoder.conv1 + struct wsp_ggml_tensor * e_conv_1_w; + struct wsp_ggml_tensor * e_conv_1_b; + + // encoder.conv2 + struct wsp_ggml_tensor * e_conv_2_w; + struct wsp_ggml_tensor * e_conv_2_b; + + // encoder.ln_post + struct wsp_ggml_tensor * e_ln_w; + struct wsp_ggml_tensor * e_ln_b; + + // decoder.positional_embedding + struct wsp_ggml_tensor * d_pe; + + // decoder.token_embedding + struct wsp_ggml_tensor * d_te; + + // decoder.ln + struct wsp_ggml_tensor * d_ln_w; + struct wsp_ggml_tensor * d_ln_b; + + std::vector layers_encoder; + std::vector layers_decoder; + + // context + struct wsp_ggml_context * ctx; + + // the model memory buffer is read-only and can be shared between processors + std::vector * buf; + + // tensors + int n_loaded; + std::map tensors; +}; + +struct whisper_sequence { + std::vector tokens; + + // the accumulated transcription in the current iteration (used to truncate the tokens array) + int result_len; + + double sum_logprobs_all; // the sum of the log probabilities of the tokens + double sum_logprobs; // the sum of the log probabilities of the tokens (first result_len tokens) + double avg_logprobs; // the average log probability of the tokens + double entropy; // the entropy of the tokens + double score; // likelihood rank score +}; + +// TAGS: WHISPER_DECODER_INIT +struct whisper_decoder { + // each decoders keeps its own KV-cache + whisper_kv_cache kv_self; + + // the currently generated sequence of tokens + whisper_sequence sequence; + + int seek_delta; // the window shift found so far based on the decoded timestamp tokens + + bool failed; // has the current segment failed to decode? + bool completed; // has the decoder completed the current segment? + bool has_ts; // have we already sampled a non-beg timestamp token for the current segment? 
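+    // Illustrative sizing note (rough sketch) for the per-decoder kv_self cache above,
+    // assuming the default tiny hparams (n_text_layer = 4, n_text_state = 384), an F16
+    // cache, and allocation for the full text context n_ctx = n_text_ctx = 448
+    // (see kv_cache_init below for the actual formula):
+    //   n_mem      = n_text_layer * n_ctx  = 4 * 448    = 1792
+    //   n_elements = n_text_state * n_mem  = 384 * 1792 = 688128
+    //   k + v      ~ 2 * 688128 * 2 bytes  ~ 2.6 MB
+    // which lines up with MEM_REQ_KV_SELF for MODEL_TINY (3 MB).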
+ + // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab]) + std::vector probs; + std::vector logits; + std::vector logprobs; + + std::vector tokens_tmp; // used for whisper_decode calls +}; + +struct whisper_state { + int64_t t_sample_us = 0; + int64_t t_encode_us = 0; + int64_t t_decode_us = 0; + int64_t t_mel_us = 0; + + int32_t n_sample = 0; // number of tokens sampled + int32_t n_encode = 0; // number of encoder calls + int32_t n_decode = 0; // number of decoder calls + int32_t n_fail_p = 0; // number of logprob threshold failures + int32_t n_fail_h = 0; // number of entropy threshold failures + + // cross-attention KV cache for the decoders + // shared between all decoders + whisper_kv_cache kv_cross; + whisper_mel mel; + + whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; + + // memory buffers used by encode / decode contexts + std::vector buf_compute; + std::vector buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS]; + + int buf_last = 0; + size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 }; + + // decode output (2-dimensional array: [n_tokens][n_vocab]) + std::vector logits; + + std::vector result_all; + std::vector prompt_past; + + // work container used to avoid memory allocations + std::vector> logits_id; + + mutable std::mt19937 rng; // used for sampling at t > 0.0 + + int lang_id = 0; // english by default + + std::string path_model; // populated by whisper_init_from_file() +#ifdef WHISPER_USE_COREML + whisper_coreml_context * ctx_coreml = nullptr; +#endif + +#ifdef WHISPER_USE_OPENVINO + whisper_openvino_context * ctx_openvino = nullptr; +#endif + + // [EXPERIMENTAL] token-level timestamps data + int64_t t_beg = 0; + int64_t t_last = 0; + whisper_token tid_last; + std::vector energy; // PCM signal energy + + // [EXPERIMENTAL] speed-up techniques + int32_t exp_n_audio_ctx = 0; // 0 - use default + + void use_buf(struct wsp_ggml_context * ctx, int i) { +#if defined(WHISPER_USE_SCRATCH) + size_t last_size = 0; + + if (i == -1) { + last_size = wsp_ggml_set_scratch(ctx, { 0, 0, nullptr, }); + } else { + auto & buf = buf_scratch[i]; + last_size = wsp_ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), }); + } + + if (buf_last >= 0) { + buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size); + } + + buf_last = i; +#else + (void) i; + (void) ctx; +#endif + } + + size_t get_buf_max_mem(int i) const { +#if defined(WHISPER_USE_SCRATCH) + return buf_max_size[i]; +#else + (void) i; + return 0; +#endif + } +}; + +struct whisper_context { + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + wsp_ggml_type wtype = wsp_ggml_type::WSP_GGML_TYPE_F16; // weight type (FP32 / FP16 / QX) + wsp_ggml_type itype = wsp_ggml_type::WSP_GGML_TYPE_F16; // intermediate type (FP32 or FP16) + + whisper_model model; + whisper_vocab vocab; + whisper_state * state = nullptr; + + std::string path_model; // populated by whisper_init_from_file() +}; + +static void whisper_default_log(const char * text) { + fprintf(stderr, "%s", text); +} + +static whisper_log_callback whisper_log = whisper_default_log; + +static void log(const char * fmt, ...) 
{ + if (!whisper_log) return; + char buf[1024]; + va_list args; + va_start(args, fmt); + vsnprintf(buf, sizeof(buf), fmt, args); + whisper_log(buf); +} + +template +static void read_safe(whisper_model_loader * loader, T & dest) { + loader->read(loader->context, &dest, sizeof(T)); + BYTESWAP_VALUE(dest); +} + +static bool kv_cache_init( + const struct whisper_hparams & hparams, + const size_t mem_bytes, + struct whisper_kv_cache & cache, + wsp_ggml_type wtype, + int n_ctx) { + cache.buf.resize(mem_bytes); + + struct wsp_ggml_init_params params = { + /*.mem_size =*/ cache.buf.size(), + /*.mem_buffer =*/ cache.buf.data(), + /*.no_alloc =*/ false, + }; + + cache.ctx = wsp_ggml_init(params); + + if (!cache.ctx) { + log("%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mem = n_text_layer*n_ctx; + const int n_elements = n_text_state*n_mem; + + cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + +static bool kv_cache_reinit(struct whisper_kv_cache & cache) { + WHISPER_ASSERT(cache.ctx); + + const int n_elements = wsp_ggml_nelements(cache.k); + WHISPER_ASSERT(n_elements == wsp_ggml_nelements(cache.v)); + + const wsp_ggml_type wtype = cache.k->type; + WHISPER_ASSERT(wtype == cache.v->type); + + WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*wsp_ggml_type_sizef(wtype)); + + struct wsp_ggml_init_params params = { + /*.mem_size =*/ cache.buf.size(), + /*.mem_buffer =*/ cache.buf.data(), + /*.no_alloc =*/ false, + }; + + cache.ctx = wsp_ggml_init(params); + + if (!cache.ctx) { + log("%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + cache.k = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = wsp_ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + +static void kv_cache_free(struct whisper_kv_cache & cache) { + if (cache.ctx) { + wsp_ggml_free(cache.ctx); + cache.ctx = nullptr; + } +} + +// load the model from a ggml file +// +// file format: +// +// - hparams +// - pre-computed mel filters +// - vocab +// - weights +// +// see the convert-pt-to-ggml.py script for details +// +static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) { + log("%s: loading model\n", __func__); + + const int64_t t_start_us = wsp_ggml_time_us(); + + wctx.t_start_us = t_start_us; + + auto & model = wctx.model; + auto & vocab = wctx.vocab; + + // verify magic + { + uint32_t magic; + read_safe(loader, magic); + if (magic != WSP_GGML_FILE_MAGIC) { + log("%s: invalid model data (bad magic)\n", __func__); + return false; + } + } + + //load hparams + { + auto & hparams = model.hparams; + + read_safe(loader, hparams.n_vocab); + read_safe(loader, hparams.n_audio_ctx); + read_safe(loader, hparams.n_audio_state); + read_safe(loader, hparams.n_audio_head); + read_safe(loader, hparams.n_audio_layer); + read_safe(loader, hparams.n_text_ctx); + read_safe(loader, hparams.n_text_state); + read_safe(loader, hparams.n_text_head); + read_safe(loader, hparams.n_text_layer); + read_safe(loader, hparams.n_mels); + read_safe(loader, hparams.ftype); + + assert(hparams.n_text_state == hparams.n_audio_state); + + if (hparams.n_audio_layer == 4) { + model.type = e_model::MODEL_TINY; + } + + if (hparams.n_audio_layer == 6) { + model.type = e_model::MODEL_BASE; + } + + if (hparams.n_audio_layer == 12) { + 
model.type = e_model::MODEL_SMALL; + } + + if (hparams.n_audio_layer == 24) { + model.type = e_model::MODEL_MEDIUM; + } + + if (hparams.n_audio_layer == 32) { + model.type = e_model::MODEL_LARGE; + } + + const int32_t qntvr = hparams.ftype / WSP_GGML_QNT_VERSION_FACTOR; + + hparams.ftype %= WSP_GGML_QNT_VERSION_FACTOR; + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + wctx.wtype = wsp_ggml_ftype_to_wsp_ggml_type((wsp_ggml_ftype) (model.hparams.ftype)); + if (wctx.wtype == WSP_GGML_TYPE_COUNT) { + log("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype); + return false; + } + + const size_t scale = model.hparams.ftype ? 1 : 2; + + log("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + log("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + log("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + log("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + log("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + log("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); + log("%s: n_text_state = %d\n", __func__, hparams.n_text_state); + log("%s: n_text_head = %d\n", __func__, hparams.n_text_head); + log("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); + log("%s: n_mels = %d\n", __func__, hparams.n_mels); + log("%s: ftype = %d\n", __func__, model.hparams.ftype); + log("%s: qntvr = %d\n", __func__, qntvr); + log("%s: type = %d\n", __func__, model.type); + + // print memory requirements + { + // this is the total memory required to run the inference + const size_t mem_required = + MEM_REQ_SCRATCH0.at(model.type) + + MEM_REQ_SCRATCH1.at(model.type) + + MEM_REQ_SCRATCH2.at(model.type) + + MEM_REQ_SCRATCH3.at(model.type) + + scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) + + scale*MEM_REQ_KV_CROSS.at(model.type) + + scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); + + // this is the memory required by one decoder + const size_t mem_required_decoder = + scale*MEM_REQ_KV_SELF.at(model.type); + + log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, + mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); + } + + // initialize all memory buffers + // always have at least one decoder + + wctx.model.buf = new std::vector(); + wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type)); + + // we skip initialization of the state until it is needed + // because it might be that state will always be provided externally. 
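+        // Worked example (illustrative): the mem_required expression above, evaluated for
+        // the tiny model with F16 weights (ftype != 0, so scale = 1), using the MEM_REQ_*
+        // tables defined earlier in this file:
+        //   mem_required         = 62 + 18 + 4 + 4   (scratch buffers 0..3)
+        //                        + 1*74               (model weights, F16 tiny)
+        //                        + 1*9                (cross-attention KV cache)
+        //                        + 1*max(30, 3)       (encode vs. decode buffer)
+        //                        = 201 MB
+        //   mem_required_decoder = 1*3 MB per decoder (self-attention KV cache)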
+ } + + // load mel filters + { + auto & filters = wctx.model.filters; + + read_safe(loader, filters.n_mel); + read_safe(loader, filters.n_fft); + + filters.data.resize(filters.n_mel * filters.n_fft); + loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float)); + BYTESWAP_FILTERS(filters); + } + + // load vocab + { + int32_t n_vocab = 0; + read_safe(loader, n_vocab); + + //if (n_vocab != model.hparams.n_vocab) { + // log("%s: invalid model file '%s' (bad vocab size %d != %d)\n", + // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); + // return false; + //} + + std::string word; + std::vector tmp; + + tmp.reserve(128); + + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + read_safe(loader, len); + + if (len > 0) { + tmp.resize(len); + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + word.assign(&tmp[0], tmp.size()); + } else { + // seems like we have an empty-string token in multi-language models (i = 50256) + //log("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + word = ""; + } + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + + //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); + } + + vocab.n_vocab = model.hparams.n_vocab; + if (vocab.is_multilingual()) { + vocab.token_eot++; + vocab.token_sot++; + vocab.token_translate++; + vocab.token_transcribe++; + vocab.token_solm++; + vocab.token_prev++; + vocab.token_nosp++; + vocab.token_not++; + vocab.token_beg++; + } + + if (n_vocab < model.hparams.n_vocab) { + log("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); + for (int i = n_vocab; i < model.hparams.n_vocab; i++) { + if (i > vocab.token_beg) { + word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]"; + } else if (i == vocab.token_eot) { + word = "[_EOT_]"; + } else if (i == vocab.token_sot) { + word = "[_SOT_]"; + } else if (i == vocab.token_solm) { + word = "[_SOLM_]"; + } else if (i == vocab.token_prev) { + word = "[_PREV_]"; + } else if (i == vocab.token_nosp) { + word = "[_NOSP_]"; + } else if (i == vocab.token_not) { + word = "[_NOT_]"; + } else if (i == vocab.token_beg) { + word = "[_BEG_]"; + } else { + word = "[_extra_token_" + std::to_string(i) + "]"; + } + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + } + + size_t ctx_size = 0; + + const wsp_ggml_type wtype = wctx.wtype; + const wsp_ggml_type vtype = wctx.wtype == WSP_GGML_TYPE_F32 ? 
WSP_GGML_TYPE_F32 : WSP_GGML_TYPE_F16; // conv type + + { + const auto & hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_state = hparams.n_audio_state; + const int n_audio_layer = hparams.n_audio_layer; + + const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mels = hparams.n_mels; + + // encoder + { + ctx_size += n_audio_ctx*n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_pe; + + ctx_size += 3*n_mels*n_audio_state*wsp_ggml_type_sizef(vtype); // e_conv_1_w + ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_conv_1_b + + ctx_size += 3*n_audio_state*n_audio_state*wsp_ggml_type_sizef(vtype); // e_conv_2_w + ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_conv_2_b + + ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_ln_w; + ctx_size += n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // e_ln_b; + } + + // decoder + { + ctx_size += n_text_ctx*n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_pe; + + ctx_size += n_vocab*n_text_state*wsp_ggml_type_sizef(wtype); // d_te; + + ctx_size += n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_ln_w; + ctx_size += n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32); // d_ln_b; + } + + // encoder layers + { + ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_w + ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_b + + ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // mlp_0_w + ctx_size += n_audio_layer*( 4*n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_0_b + + ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // mlp_1_w + ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_1_b + + ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_w + ctx_size += n_audio_layer*(n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_b + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_q_w + ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_q_b + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_k_w + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_v_w + ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_v_b + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*wsp_ggml_type_sizef(wtype)); // attn_ln_1_w + ctx_size += n_audio_layer*( n_audio_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_1_b + } + + // decoder layers + { + ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_w + ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_ln_b + + ctx_size += n_text_layer*(4*n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // mlp_0_w + ctx_size += n_text_layer*( 4*n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_0_b + + ctx_size += n_text_layer*(4*n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // mlp_1_w + ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // mlp_1_b + + ctx_size += 
n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_w + ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_0_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_q_w + ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_q_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_k_w + + ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_v_w + ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_v_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // attn_ln_1_w + ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // attn_ln_1_b + // + ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_0_w + ctx_size += n_text_layer*(n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_0_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_q_w + ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_q_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_k_w + + ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_v_w + ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_v_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*wsp_ggml_type_sizef(wtype)); // cross_attn_ln_1_w + ctx_size += n_text_layer*( n_text_state*wsp_ggml_type_sizef(WSP_GGML_TYPE_F32)); // cross_attn_ln_1_b + } + + ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead + + log("%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + } + + // create the ggml context + { + struct wsp_ggml_init_params params = { + /*.mem_size =*/ wctx.model.buf->size(), + /*.mem_buffer =*/ wctx.model.buf->data(), + /*.no_alloc =*/ false, + }; + + model.ctx = wsp_ggml_init(params); + if (!model.ctx) { + log("%s: wsp_ggml_init() failed\n", __func__); + return false; + } + } + + // prepare memory for the weights + { + auto & ctx = model.ctx; + + const auto & hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_state = hparams.n_audio_state; + const int n_audio_layer = hparams.n_audio_layer; + + const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mels = hparams.n_mels; + + model.layers_encoder.resize(n_audio_layer); + model.layers_decoder.resize(n_text_layer); + + // encoder + { + model.e_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx); + + model.e_conv_1_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state); + model.e_conv_1_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state); + + model.e_conv_2_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state); + model.e_conv_2_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state); + + model.e_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); + model.e_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); + + // map by name + model.tensors["encoder.positional_embedding"] = 
model.e_pe; + + model.tensors["encoder.conv1.weight"] = model.e_conv_1_w; + model.tensors["encoder.conv1.bias"] = model.e_conv_1_b; + + model.tensors["encoder.conv2.weight"] = model.e_conv_2_w; + model.tensors["encoder.conv2.bias"] = model.e_conv_2_b; + + model.tensors["encoder.ln_post.weight"] = model.e_ln_w; + model.tensors["encoder.ln_post.bias"] = model.e_ln_b; + + for (int i = 0; i < n_audio_layer; ++i) { + auto & layer = model.layers_encoder[i]; + + layer.mlp_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); + layer.mlp_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); + + layer.mlp_0_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state); + layer.mlp_0_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 4*n_audio_state); + + layer.mlp_1_w = wsp_ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state); + layer.mlp_1_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); + + layer.attn_ln_0_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); + layer.attn_ln_0_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); + + layer.attn_q_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_q_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); + + layer.attn_k_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + + layer.attn_v_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_v_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); + + layer.attn_ln_1_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_ln_1_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state); + + // map by name + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; + model.tensors["encoder.blocks." 
+ std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; + } + } + + // decoder + { + model.d_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_text_state, n_text_ctx); + + model.d_te = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab); + + model.d_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + model.d_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + + // map by name + model.tensors["decoder.positional_embedding"] = model.d_pe; + + model.tensors["decoder.token_embedding.weight"] = model.d_te; + + model.tensors["decoder.ln.weight"] = model.d_ln_w; + model.tensors["decoder.ln.bias"] = model.d_ln_b; + + for (int i = 0; i < n_text_layer; ++i) { + auto & layer = model.layers_decoder[i]; + + layer.mlp_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + layer.mlp_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + + layer.mlp_0_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state); + layer.mlp_0_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 4*n_text_state); + + layer.mlp_1_w = wsp_ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state); + layer.mlp_1_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + + layer.attn_ln_0_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + layer.attn_ln_0_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + + layer.attn_q_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_q_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + + layer.attn_k_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + + layer.attn_v_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_v_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + + layer.attn_ln_1_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_ln_1_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + + layer.cross_attn_ln_0_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + layer.cross_attn_ln_0_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + + layer.cross_attn_q_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_q_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + + layer.cross_attn_k_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + + layer.cross_attn_v_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_v_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + + layer.cross_attn_ln_1_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_ln_1_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state); + + // map by name + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; + model.tensors["decoder.blocks." 
+ std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b; + } + } + } + + // load weights + { + size_t total_size = 0; + + model.n_loaded = 0; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + read_safe(loader, n_dims); + read_safe(loader, length); + read_safe(loader, ttype); + + if (loader->eof(loader->context)) { + break; + } + + int32_t nelements = 1; + int32_t ne[4] = { 1, 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + read_safe(loader, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector tmp(length); // create a buffer + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + name.assign(&tmp[0], tmp.size()); + + if (model.tensors.find(name) == model.tensors.end()) { + log("%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model.tensors[name.data()]; + if (wsp_ggml_nelements(tensor) != nelements) { + log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", + __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { + log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", + __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]); + return false; + } + + const size_t bpe = wsp_ggml_type_size(wsp_ggml_type(ttype)); + + if ((nelements*bpe)/wsp_ggml_blck_size(tensor->type) != wsp_ggml_nbytes(tensor)) { + log("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), wsp_ggml_nbytes(tensor), nelements*bpe); + return false; + } + + loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor)); + 
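+            // one tensor record has now been consumed from the loader:
+            // [int32 n_dims][int32 name length][int32 ttype][int32 ne[0..n_dims-1]][name bytes][raw tensor data]
+            // the byte swap below should only matter on big-endian hosts (assumption: the macro is
+            // defined earlier in this file and is expected to be a no-op on little-endian builds)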
BYTESWAP_TENSOR(tensor); + + //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], wsp_ggml_type_name((wsp_ggml_type) ttype), wsp_ggml_nbytes(tensor)/1024.0/1024.0); + total_size += wsp_ggml_nbytes(tensor); + model.n_loaded++; + } + + log("%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); + + if (model.n_loaded == 0) { + log("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } else if (model.n_loaded != (int) model.tensors.size()) { + log("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + return false; + } + } + + wctx.t_load_us = wsp_ggml_time_us() - t_start_us; + + return true; +} + +// evaluate the encoder with the given state +// +// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder +// part of the transformer model and returns the encoded features +// +// - wctx: the model +// - wstate: the state of the encoder +// - n_threads: number of threads to use +// - mel_offset: offset in the mel spectrogram (i.e. audio offset) +// +static bool whisper_encode_internal( + whisper_context & wctx, + whisper_state & wstate, + const int mel_offset, + const int n_threads){ + + const int64_t t_start_us = wsp_ggml_time_us(); + + const auto & model = wctx.model; + const auto & mel_inp = wstate.mel; + const auto & hparams = model.hparams; + + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + const int n_layer = hparams.n_audio_layer; + + const int n_mels = hparams.n_mels; + assert(mel_inp.n_mel == n_mels); + + struct wsp_ggml_init_params params = { + /*.mem_size =*/ wstate.buf_compute.size(), + /*.mem_buffer =*/ wstate.buf_compute.data(), + /*.no_alloc =*/ false, + }; + + struct wsp_ggml_context * ctx0 = wsp_ggml_init(params); + + wstate.use_buf(ctx0, 0); + + struct wsp_ggml_tensor * mel = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, 2*n_ctx, n_mels); + assert(mel->type == WSP_GGML_TYPE_F32); + { + float * dst = (float *) mel->data; + memset(dst, 0, wsp_ggml_nbytes(mel)); + + const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); + + for (int j = 0; j < mel_inp.n_mel; ++j) { + for (int i = i0; i < i1; ++i) { + dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; + } + } + } + + struct wsp_ggml_tensor * cur; + +#ifndef WHISPER_USE_COREML + const bool use_coreml = false; +#else + const bool use_coreml = wstate.ctx_coreml != nullptr; +#endif + +#ifndef WHISPER_USE_OPENVINO + const bool use_openvino = false; +#else + const bool use_openvino = wstate.ctx_openvino != nullptr; +#endif + + if (!use_coreml && !use_openvino) { + // convolution + gelu + { + wstate.use_buf(ctx0, 1); + + cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); + cur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, + model.e_conv_1_b, + cur), + cur); + + cur = wsp_ggml_gelu(ctx0, cur); + + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1); + cur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, + model.e_conv_2_b, + cur), + cur); + + cur = wsp_ggml_gelu(ctx0, cur); + } + + wstate.use_buf(ctx0, 3); + + // =================================================================== + // NOTE: experimenting with partial evaluation of the encoder (ignore) + //static int iter = -1; + 
//const int n_iter = 1500/n_ctx; + + //iter = (iter + 1) % n_iter; + + //if (iter == 0) { + // memset(model.memory_cross_k->data, 0, wsp_ggml_nbytes(model.memory_cross_k)); + // memset(model.memory_cross_v->data, 0, wsp_ggml_nbytes(model.memory_cross_v)); + //} + + static int iter = 0; + + const size_t e_pe_stride = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe); + const size_t e_pe_offset = model.e_pe->ne[0]*wsp_ggml_element_size(model.e_pe)*n_ctx*iter; + + struct wsp_ggml_tensor * e_pe = wsp_ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); + + cur = wsp_ggml_add(ctx0, e_pe, wsp_ggml_transpose(ctx0, cur)); + + // =================================================================== + + // original: + //cur = wsp_ggml_add(ctx0, model.e_pe, wsp_ggml_transpose(ctx0, cur)); + + struct wsp_ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_encoder[il]; + + // norm + { + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_norm(ctx0, inpL); + + // cur = ln_0_w*cur + ln_0_b + cur = wsp_ggml_add(ctx0, + wsp_ggml_mul(ctx0, + wsp_ggml_repeat(ctx0, layer.attn_ln_0_w, cur), + cur), + wsp_ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); + } + + // self-attention + { + wstate.use_buf(ctx0, 1); + + struct wsp_ggml_tensor * Qcur = wsp_ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, + layer.attn_q_b, + Qcur), + Qcur); + + //Qcur = wsp_ggml_scale_inplace(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + // note: no bias for Key + struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + //Kcur = wsp_ggml_scale_inplace(ctx0, Kcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, + layer.attn_v_b, + Vcur), + Vcur); + + // ------ + + wstate.use_buf(ctx0, 0); + +#ifdef WHISPER_USE_FLASH_ATTN + struct wsp_ggml_tensor * Q = + wsp_ggml_permute(ctx0, + wsp_ggml_cpy(ctx0, + Qcur, + wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct wsp_ggml_tensor * K = + wsp_ggml_permute(ctx0, + wsp_ggml_cpy(ctx0, + Kcur, + wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct wsp_ggml_tensor * V = + wsp_ggml_cpy(ctx0, + wsp_ggml_permute(ctx0, + wsp_ggml_reshape_3d(ctx0, + Vcur, + n_state/n_head, n_head, n_ctx), + 1, 2, 0, 3), + wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)); + + struct wsp_ggml_tensor * KQV = wsp_ggml_flash_attn(ctx0, Q, K, V, false); +#else + struct wsp_ggml_tensor * Q = + wsp_ggml_permute(ctx0, + wsp_ggml_cpy(ctx0, + Qcur, + wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct wsp_ggml_tensor * K = + wsp_ggml_permute(ctx0, + wsp_ggml_cpy(ctx0, + Kcur, + wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + // K * Q + struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q); + + struct wsp_ggml_tensor * KQ_scaled = + wsp_ggml_scale_inplace(ctx0, + KQ, + wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) + ); + + struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_inplace(ctx0, KQ_scaled); + + struct wsp_ggml_tensor * V = + wsp_ggml_cpy(ctx0, + wsp_ggml_permute(ctx0, + wsp_ggml_reshape_3d(ctx0, + Vcur, + n_state/n_head, n_head, n_ctx), + 1, 2, 0, 3), + 
wsp_ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head) + ); + + struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max); +#endif + struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + wstate.use_buf(ctx0, 1); + + cur = wsp_ggml_cpy(ctx0, + KQV_merged, + wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx)); + } + + // projection + { + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + wstate.use_buf(ctx0, 1); + + cur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, layer.attn_ln_1_b, cur), + cur); + } + + wstate.use_buf(ctx0, 2); + + // add the input + cur = wsp_ggml_add(ctx0, cur, inpL); + + struct wsp_ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_norm(ctx0, inpFF); + + wstate.use_buf(ctx0, 1); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = wsp_ggml_add(ctx0, + wsp_ggml_mul(ctx0, + wsp_ggml_repeat(ctx0, layer.mlp_ln_w, cur), + cur), + wsp_ggml_repeat(ctx0, layer.mlp_ln_b, cur)); + } + +#ifdef WHISPER_USE_FLASH_FF + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_flash_ff(ctx0, + wsp_ggml_cpy(ctx0, cur, wsp_ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)), + layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); +#else + wstate.use_buf(ctx0, 0); + + // fully connected + cur = wsp_ggml_mul_mat(ctx0, + layer.mlp_0_w, + cur); + + wstate.use_buf(ctx0, 1); + + cur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, layer.mlp_0_b, cur), + cur); + + wstate.use_buf(ctx0, 0); + + // GELU activation + cur = wsp_ggml_gelu(ctx0, cur); + + wstate.use_buf(ctx0, 1); + + // projection + cur = wsp_ggml_mul_mat(ctx0, + layer.mlp_1_w, + cur); + + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, layer.mlp_1_b, cur), + cur); +#endif + } + + wstate.use_buf(ctx0, 3); + + inpL = wsp_ggml_add(ctx0, cur, inpFF); + } + + cur = inpL; + + // norm + { + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_norm(ctx0, cur); + + wstate.use_buf(ctx0, 1); + + // cur = ln_f_g*cur + ln_f_b + cur = wsp_ggml_add(ctx0, + wsp_ggml_mul(ctx0, + wsp_ggml_repeat(ctx0, model.e_ln_w, cur), + cur), + wsp_ggml_repeat(ctx0, model.e_ln_b, cur)); + } + + wstate.use_buf(ctx0, -1); + + // run the computation + { + struct wsp_ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + wsp_ggml_build_forward_expand(&gf, cur); + wsp_ggml_graph_compute(ctx0, &gf); + + //wsp_ggml_graph_print(&gf); + } + } +#ifdef WHISPER_USE_COREML + else if (use_coreml) { + wstate.use_buf(ctx0, -1); + + cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx); + + whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data); + } +#endif +#ifdef WHISPER_USE_OPENVINO + else if (use_openvino) { + wstate.use_buf(ctx0, -1); + + cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx); + + if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) { + return false; + } + } +#endif + + // cur + //{ + // printf("ne0 = %d\n", cur->ne[0]); + // printf("ne1 = %d\n", cur->ne[1]); + // for (int i = 0; i < 10; ++i) { + // printf("%8.4f ", ((float *)(cur->data))[i]); + // } + // printf("... 
"); + // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { + // printf("%8.4f ", ((float *)(cur->data))[i]); + // } + // printf("\n"); + //} + + // pre-compute cross-attention memory + { + struct wsp_ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + // TODO: hack to disconnect the encoded features from the previous graph + cur->op = WSP_GGML_OP_NONE; + cur->src0 = nullptr; + cur->src1 = nullptr; + + for (int il = 0; il < model.hparams.n_text_layer; ++il) { + auto& layer = model.layers_decoder[il]; + + wstate.use_buf(ctx0, 0); + + struct wsp_ggml_tensor* Kcross = wsp_ggml_mul_mat(ctx0, + layer.cross_attn_k_w, + cur); + + Kcross = wsp_ggml_scale_inplace(ctx0, Kcross, wsp_ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25))); + + wstate.use_buf(ctx0, 1); + + struct wsp_ggml_tensor* Vcross = wsp_ggml_mul_mat(ctx0, + layer.cross_attn_v_w, + cur); + + Vcross = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, + layer.cross_attn_v_b, + Vcross), + Vcross); + + wstate.use_buf(ctx0, -1); + + Vcross = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); + + struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (wsp_ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, + ( n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v), + (il*n_ctx)*wsp_ggml_element_size(wstate.kv_cross.v)*n_state); + + wsp_ggml_build_forward_expand(&gf, wsp_ggml_cpy(ctx0, Kcross, k)); + wsp_ggml_build_forward_expand(&gf, wsp_ggml_cpy(ctx0, Vcross, v)); + } + + wsp_ggml_graph_compute(ctx0, &gf); + //wsp_ggml_graph_print(&gf); + } + + //////////////////////////////////////////////////////////////////////////// + + //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // wsp_ggml_used_mem(ctx0)/1024.0/1024.0, + // wstate.get_buf_max_mem(0)/1024.0/1024.0, + // wstate.get_buf_max_mem(1)/1024.0/1024.0, + // wstate.get_buf_max_mem(2)/1024.0/1024.0, + // wstate.get_buf_max_mem(3)/1024.0/1024.0); + + wsp_ggml_free(ctx0); + + wstate.t_encode_us += wsp_ggml_time_us() - t_start_us; + wstate.n_encode++; + + return true; +} + +// evaluate the decoder +// +// given text prompt + audio features -> computes the logits for the next token +// +// - model: the model +// - n_threads: number of threads to use +// - tokens: text prompt +// - n_tokens: number of tokens in the prompt +// - n_past: number of past tokens to prefix the prompt with +// +static bool whisper_decode_internal( + whisper_context & wctx, + whisper_state & wstate, + whisper_decoder & decoder, + const whisper_token * tokens, + const int n_tokens, + const int n_past, + const int n_threads) { + const int64_t t_start_us = wsp_ggml_time_us(); + + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + auto & kv_self = decoder.kv_self; + + WHISPER_ASSERT(!!kv_self.ctx); + + auto & logits_out = wstate.logits; + + const int n_vocab = hparams.n_vocab; + + const int n_ctx = hparams.n_text_ctx; + const int n_state = hparams.n_text_state; + const int n_head = hparams.n_text_head; + const int n_layer = hparams.n_text_layer; + + const int N = n_tokens; + const int M = wstate.exp_n_audio_ctx > 0 ? 
wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + + //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); + + struct wsp_ggml_init_params params = { + /*.mem_size =*/ wstate.buf_compute.size(), + /*.mem_buffer =*/ wstate.buf_compute.data(), + /*.no_alloc =*/ false, + }; + + struct wsp_ggml_context * ctx0 = wsp_ggml_init(params); + + struct wsp_ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + struct wsp_ggml_tensor * embd = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N); + memcpy(embd->data, tokens, N*wsp_ggml_element_size(embd)); + + struct wsp_ggml_tensor * position = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_I32, N); + for (int i = 0; i < N; ++i) { + ((int32_t *) position->data)[i] = n_past + i; + } + + wstate.use_buf(ctx0, 3); + + // token encoding + position encoding + struct wsp_ggml_tensor * cur = + wsp_ggml_add(ctx0, + wsp_ggml_get_rows(ctx0, model.d_te, embd), + wsp_ggml_get_rows(ctx0, model.d_pe, position)); + + struct wsp_ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_decoder[il]; + + // norm + { + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_norm(ctx0, inpL); + + // cur = ln_0_w*cur + ln_0_b + cur = wsp_ggml_add(ctx0, + wsp_ggml_mul(ctx0, + wsp_ggml_repeat(ctx0, layer.attn_ln_0_w, cur), + cur), + wsp_ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); + } + + // self-attention + { + struct wsp_ggml_tensor * Qcur = wsp_ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, + layer.attn_q_b, + Qcur), + Qcur); + + Qcur = wsp_ggml_scale_inplace(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + // note: no bias for Key + struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + Kcur = wsp_ggml_scale_inplace(ctx0, Kcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + // store key and value to memory + { + struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, + layer.attn_v_b, + Vcur), + Vcur); + + Vcur = wsp_ggml_transpose(ctx0, wsp_ggml_reshape_2d(ctx0, Vcur, n_state, N)); + + struct wsp_ggml_tensor * k = wsp_ggml_view_1d(ctx0, kv_self.k, N*n_state, (wsp_ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); + struct wsp_ggml_tensor * v = wsp_ggml_view_2d(ctx0, kv_self.v, N, n_state, + ( n_ctx)*wsp_ggml_element_size(kv_self.v), + (il*n_ctx)*wsp_ggml_element_size(kv_self.v)*n_state + n_past*wsp_ggml_element_size(kv_self.v)); + + wsp_ggml_build_forward_expand(&gf, wsp_ggml_cpy(ctx0, Kcur, k)); + wsp_ggml_build_forward_expand(&gf, wsp_ggml_cpy(ctx0, Vcur, v)); + } + + // ------ + + wstate.use_buf(ctx0, 0); + + struct wsp_ggml_tensor * Q = + wsp_ggml_permute(ctx0, + wsp_ggml_cpy(ctx0, + Qcur, + wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, N)), + 0, 2, 1, 3); + + struct wsp_ggml_tensor * K = + wsp_ggml_permute(ctx0, + wsp_ggml_reshape_3d(ctx0, + wsp_ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*wsp_ggml_element_size(kv_self.k)*n_state), + n_state/n_head, n_head, n_past + N), + 0, 2, 1, 3); + + wstate.use_buf(ctx0, 1); + + // K * Q + struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q); + + //struct wsp_ggml_tensor * KQ_scaled = + // wsp_ggml_scale_inplace(ctx0, + // KQ, + // wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) + // ); + + struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf_inplace(ctx0, KQ, n_past); + + struct 
wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_inplace(ctx0, KQ_masked); + + struct wsp_ggml_tensor * V = + wsp_ggml_view_3d(ctx0, kv_self.v, + n_past + N, n_state/n_head, n_head, + n_ctx*wsp_ggml_element_size(kv_self.v), + n_ctx*wsp_ggml_element_size(kv_self.v)*n_state/n_head, + il*n_ctx*wsp_ggml_element_size(kv_self.v)*n_state); + + struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max); + + struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = wsp_ggml_cpy(ctx0, + KQV_merged, + wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, N)); + } + + // projection + { + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + wstate.use_buf(ctx0, 1); + + cur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, layer.attn_ln_1_b, cur), + cur); + } + + wstate.use_buf(ctx0, 2); + + // add the input + struct wsp_ggml_tensor * inpCA = wsp_ggml_add(ctx0, cur, inpL); + + // norm + { + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_norm(ctx0, inpCA); // note: we use inpCA here + + // cur = ln_0_w*cur + ln_0_b + cur = wsp_ggml_add(ctx0, + wsp_ggml_mul(ctx0, + wsp_ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur), + cur), + wsp_ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur)); + } + + // cross-attention + { + struct wsp_ggml_tensor * Qcur = wsp_ggml_mul_mat(ctx0, + layer.cross_attn_q_w, + cur); + + Qcur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, + layer.cross_attn_q_b, + Qcur), + Qcur); + + Qcur = wsp_ggml_scale_inplace(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + // Kcross is already scaled + struct wsp_ggml_tensor * Kcross = + wsp_ggml_reshape_3d(ctx0, + wsp_ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*wsp_ggml_element_size(wstate.kv_cross.k)*n_state), + n_state/n_head, n_head, M); + + //struct wsp_ggml_tensor * Vcross = + // wsp_ggml_reshape_3d(ctx0, + // wsp_ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*wsp_ggml_element_size(wstate.kv_cross.v)*n_state), + // n_state/n_head, n_head, M); + + //struct wsp_ggml_tensor * V_trans = + // wsp_ggml_cpy(ctx0, + // wsp_ggml_permute(ctx0, Vcross, 1, 2, 0, 3), + // wsp_ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head)); + + struct wsp_ggml_tensor * V = + wsp_ggml_view_3d(ctx0, wstate.kv_cross.v, + M, n_state/n_head, n_head, + M*wsp_ggml_element_size(wstate.kv_cross.v), + M*wsp_ggml_element_size(wstate.kv_cross.v)*n_state/n_head, + il*M*wsp_ggml_element_size(wstate.kv_cross.v)*n_state); + + // ------ + + struct wsp_ggml_tensor * Q = + wsp_ggml_permute(ctx0, + wsp_ggml_cpy(ctx0, + Qcur, + wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_state/n_head, n_head, N)), + 0, 2, 1, 3); + + struct wsp_ggml_tensor * K = wsp_ggml_permute(ctx0, Kcross, 0, 2, 1, 3); + + // K * Q + struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q); + + //struct wsp_ggml_tensor * KQ_scaled = + // wsp_ggml_scale_inplace(ctx0, + // KQ, + // wsp_ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) + // ); + + // no masking for cross-attention + //struct wsp_ggml_tensor * KQ_masked = wsp_ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + + struct wsp_ggml_tensor * KQ_soft_max = wsp_ggml_soft_max_inplace(ctx0, KQ); + + struct wsp_ggml_tensor * KQV = wsp_ggml_mul_mat(ctx0, V, KQ_soft_max); + + struct wsp_ggml_tensor * KQV_merged = wsp_ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_state, N) + cur = wsp_ggml_cpy(ctx0, + KQV_merged, + wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, N)); + } + + // 
projection + { + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_mul_mat(ctx0, + layer.cross_attn_ln_1_w, + cur); + + wstate.use_buf(ctx0, 1); + + cur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur), + cur); + } + + wstate.use_buf(ctx0, 2); + + // add the input + cur = wsp_ggml_add(ctx0, cur, inpCA); + + struct wsp_ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_norm(ctx0, inpFF); + + wstate.use_buf(ctx0, 1); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = wsp_ggml_add(ctx0, + wsp_ggml_mul(ctx0, + wsp_ggml_repeat(ctx0, layer.mlp_ln_w, cur), + cur), + wsp_ggml_repeat(ctx0, layer.mlp_ln_b, cur)); + } + + wstate.use_buf(ctx0, 0); + + // fully connected + cur = wsp_ggml_mul_mat(ctx0, + layer.mlp_0_w, + cur); + + wstate.use_buf(ctx0, 1); + + cur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, layer.mlp_0_b, cur), + cur); + + wstate.use_buf(ctx0, 0); + + // GELU activation + cur = wsp_ggml_gelu(ctx0, cur); + + wstate.use_buf(ctx0, 1); + + // projection + cur = wsp_ggml_mul_mat(ctx0, + layer.mlp_1_w, + cur); + + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_add(ctx0, + wsp_ggml_repeat(ctx0, layer.mlp_1_b, cur), + cur); + } + + wstate.use_buf(ctx0, 3); + + inpL = wsp_ggml_add(ctx0, cur, inpFF); + } + + cur = inpL; + + // norm + { + wstate.use_buf(ctx0, 0); + + cur = wsp_ggml_norm(ctx0, cur); + + wstate.use_buf(ctx0, 1); + + cur = wsp_ggml_add(ctx0, + wsp_ggml_mul(ctx0, + wsp_ggml_repeat(ctx0, model.d_ln_w, cur), + cur), + wsp_ggml_repeat(ctx0, model.d_ln_b, cur)); + } + + wstate.use_buf(ctx0, 0); + + // compute logits only for the last token + // comment this line to compute logits for all N tokens + // might be useful in the future + cur = wsp_ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); + + struct wsp_ggml_tensor * logits = wsp_ggml_mul_mat(ctx0, model.d_te, cur); + + wstate.use_buf(ctx0, -1); + + // run the computation + { + wsp_ggml_build_forward_expand(&gf, logits); + wsp_ggml_graph_compute (ctx0, &gf); + } + + // extract logits for all N tokens + //logits_out.resize(N*n_vocab); + //memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*N*n_vocab); + + // extract logits only for the last token + logits_out.resize(n_vocab); + memcpy(logits_out.data(), wsp_ggml_get_data(logits), sizeof(float)*n_vocab); + + if (N > 1) { + //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // wsp_ggml_used_mem(ctx0)/1024.0/1024.0, + // wstate.get_buf_max_mem(0)/1024.0/1024.0, + // wstate.get_buf_max_mem(1)/1024.0/1024.0, + // wstate.get_buf_max_mem(2)/1024.0/1024.0, + // wstate.get_buf_max_mem(3)/1024.0/1024.0); + } + + wsp_ggml_free(ctx0); + + wstate.t_decode_us += wsp_ggml_time_us() - t_start_us; + wstate.n_decode++; + + return true; +} + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +static std::string to_timestamp(int64_t t, bool comma = false) { + int64_t msec = t * 10; + int64_t hr = msec / (1000 * 60 * 60); + msec = msec - hr * (1000 * 60 * 60); + int64_t min = msec / (1000 * 60); + msec = msec - min * (1000 * 60); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; + + char buf[32]; + snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec); + + return std::string(buf); +} + +#define SIN_COS_N_COUNT WHISPER_N_FFT +static float sin_vals[SIN_COS_N_COUNT]; +static float cos_vals[SIN_COS_N_COUNT]; + +// In FFT, we frequently use sine and cosine operations with the same values. 
+// We can use precalculated values to speed up the process.
+static void fill_sin_cos_table() {
+    static bool is_filled = false;
+    if (is_filled) return;
+    for (int i = 0; i < SIN_COS_N_COUNT; i++) {
+        double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
+        sin_vals[i] = sinf(theta);
+        cos_vals[i] = cosf(theta);
+    }
+    is_filled = true;
+}
+
+// naive Discrete Fourier Transform
+// input is real-valued
+// output is complex-valued
+static void dft(const std::vector<float> & in, std::vector<float> & out) {
+    int N = in.size();
+
+    out.resize(N*2);
+    const int sin_cos_step = SIN_COS_N_COUNT / N;
+
+    for (int k = 0; k < N; k++) {
+        float re = 0;
+        float im = 0;
+
+        for (int n = 0; n < N; n++) {
+            int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
+            re += in[n]*cos_vals[idx]; // cos(t)
+            im -= in[n]*sin_vals[idx]; // sin(t)
+        }
+
+        out[k*2 + 0] = re;
+        out[k*2 + 1] = im;
+    }
+}
+
+// Cooley-Tukey FFT
+// poor man's implementation - use something better
+// input is real-valued
+// output is complex-valued
+static void fft(const std::vector<float> & in, std::vector<float> & out) {
+    out.resize(in.size()*2);
+
+    int N = in.size();
+
+    if (N == 1) {
+        out[0] = in[0];
+        out[1] = 0;
+        return;
+    }
+
+    if (N%2 == 1) {
+        dft(in, out);
+        return;
+    }
+
+    std::vector<float> even;
+    std::vector<float> odd;
+
+    even.reserve(N/2);
+    odd.reserve(N/2);
+
+    for (int i = 0; i < N; i++) {
+        if (i % 2 == 0) {
+            even.push_back(in[i]);
+        } else {
+            odd.push_back(in[i]);
+        }
+    }
+
+    std::vector<float> even_fft;
+    std::vector<float> odd_fft;
+
+    fft(even, even_fft);
+    fft(odd, odd_fft);
+
+    const int sin_cos_step = SIN_COS_N_COUNT / N;
+    for (int k = 0; k < N/2; k++) {
+        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
+        float re = cos_vals[idx]; // cos(t)
+        float im = -sin_vals[idx]; // sin(t)
+
+        float re_odd = odd_fft[2*k + 0];
+        float im_odd = odd_fft[2*k + 1];
+
+        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
+        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
+
+        out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+        out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
+    }
+}
+
+static bool hann_window(int length, bool periodic, std::vector<float> & output) {
+    if (output.size() < length) {
+        output.resize(length);
+    }
+    int offset = -1;
+    if (periodic) {
+        offset = 0;
+    }
+    for (int i = 0; i < length; i++) {
+        output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
+    }
+
+    return true;
+}
+
+static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
+                                              int n_samples, int frame_size, int frame_step, int n_threads,
+                                              const whisper_filters & filters, whisper_mel & mel) {
+    std::vector<float> fft_in(frame_size, 0.0);
+    std::vector<float> fft_out(2 * frame_step);
+    // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
+    int n_fft = 1 + (frame_size / 2);
+    int i = ith;
+
+    // calculate FFT only when fft_in are not all zero
+    for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
+        const int offset = i * frame_step;
+
+        // apply Hanning window (~10% faster)
+        for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
+            fft_in[j] = hann[j] * samples[offset + j];
+        }
+        // fill the rest with zeros
+        if (n_samples - offset < frame_size) {
+            std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
+        }
+
+        // FFT
+        fft(fft_in, fft_out);
+
+        // Calculate modulus^2 of complex numbers
+        // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
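+        // fft_out holds interleaved complex values (re, im) per bin; the loop below rewrites
+        // fft_out[j] in place with the power |X_j|^2 = re*re + im*im. The in-place update is safe
+        // because the write index j never runs ahead of the read index 2*j, and only the first
+        // n_fft = 1 + frame_size/2 bins are consumed by the mel filter bank further down.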
+        for (int j = 0; j < frame_size; j++) {
+            fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
+        }
+
+        // mel spectrogram
+        for (int j = 0; j < mel.n_mel; j++) {
+            double sum = 0.0;
+
+            // unroll loop (suggested by GH user @lunixbochs)
+            int k = 0;
+            for (k = 0; k < n_fft - 3; k += 4) {
+                sum +=
+                        fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
+                        fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
+                        fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
+                        fft_out[k + 3] * filters.data[j * n_fft + k + 3];
+            }
+
+            // handle n_fft remainder
+            for (; k < n_fft; k++) {
+                sum += fft_out[k] * filters.data[j * n_fft + k];
+            }
+
+            sum = log10(std::max(sum, 1e-10));
+
+            mel.data[j * mel.n_len + i] = sum;
+        }
+    }
+
+    // Otherwise fft_out are all zero
+    double sum = log10(1e-10);
+    for (; i < mel.n_len; i += n_threads) {
+        for (int j = 0; j < mel.n_mel; j++) {
+            mel.data[j * mel.n_len + i] = sum;
+        }
+    }
+}
+
+// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
+static bool log_mel_spectrogram(
+        whisper_state & wstate,
+        const float * samples,
+        const int n_samples,
+        const int /*sample_rate*/,
+        const int frame_size,
+        const int frame_step,
+        const int n_mel,
+        const int n_threads,
+        const whisper_filters & filters,
+        const bool debug,
+        whisper_mel & mel) {
+    const int64_t t_start_us = wsp_ggml_time_us();
+
+    // Hanning window (Use cosf to eliminate difference)
+    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
+    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
+    std::vector<float> hann;
+    hann_window(frame_size, true, hann);
+
+
+    // Calculate the length of padding
+    int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
+    int64_t stage_2_pad = frame_size / 2;
+
+    // Initialize a vector and copy data from C array to it.
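+    // Resulting layout of samples_padded:
+    //   [stage_2_pad reflected samples][n_samples original samples][stage_1_pad + stage_2_pad padding]
+    // e.g. assuming the usual WHISPER_SAMPLE_RATE of 16000 and a frame_size of 400 (so stage_1_pad = 480000
+    // and stage_2_pad = 200), the padded buffer is n_samples + 480400 samples long.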
+ std::vector samples_padded; + samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); + std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); + + // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio + std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); + + // reflective pad 200 samples at the beginning of audio + std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); + + mel.n_mel = n_mel; + // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 + // Calculate number of frames + remove the last frame + mel.n_len = (samples_padded.size() - frame_size) / frame_step; + // Calculate semi-padded sample length to ensure compatibility + mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step; + mel.data.resize(mel.n_mel * mel.n_len); + + + { + std::vector workers(n_threads - 1); + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw] = std::thread( + log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded, + n_samples + stage_2_pad, frame_size, frame_step, n_threads, + std::cref(filters), std::ref(mel)); + } + + // main thread + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel); + + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw].join(); + } + } + + // clamping and normalization + double mmax = -1e20; + for (int i = 0; i < mel.n_mel*mel.n_len; i++) { + if (mel.data[i] > mmax) { + mmax = mel.data[i]; + } + } + + mmax -= 8.0; + + for (int i = 0; i < mel.n_mel*mel.n_len; i++) { + if (mel.data[i] < mmax) { + mel.data[i] = mmax; + } + + mel.data[i] = (mel.data[i] + 4.0)/4.0; + } + + wstate.t_mel_us += wsp_ggml_time_us() - t_start_us; + + // Dump log_mel_spectrogram + if (debug) { + std::ofstream outFile("log_mel_spectrogram.json"); + outFile << "["; + for (uint64_t i = 0; i < mel.data.size() - 1; i++) { + outFile << mel.data[i] << ", "; + } + outFile << mel.data[mel.data.size() - 1] << "]"; + outFile.close(); + } + + return true; +} + +// split text into tokens +// +// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53 +// +// Regex (Python): +// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" +// +// Regex (C++): +// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)" +// +static std::vector tokenize(const whisper_vocab & vocab, const std::string & text) { + std::vector words; + + // first split the text into words + { + std::string str = text; + std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; + + std::regex re(pat); + std::smatch m; + + while (std::regex_search(str, m, re)) { + for (auto x : m) { + words.push_back(x); + } + str = m.suffix(); + } + } + + // find the longest tokens that form the words: + std::vector tokens; + for (const auto & word : words) { + if (word.empty()) continue; + + int i = 0; + int n = word.size(); + while (i < n) { + int j = n; + bool found = false; + while (j > i) { + auto sub = word.substr(i, j-i); + auto it = vocab.token_to_id.find(sub); + if (it != vocab.token_to_id.end()) { + tokens.push_back(it->second); + i = j; + found = true; + break; + } + --j; + } + if (!found) { + log("unknown token\n"); + 
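+                // no sub-token matched at position i: skip a single character and keep going,
+                // so one unknown byte does not abort tokenization of the rest of the word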
++i; + } + } + } + + return tokens; +} + +// +// interface implementation +// + +#ifdef WHISPER_USE_COREML +// replace .bin with -encoder.mlmodelc +static std::string whisper_get_coreml_path_encoder(std::string path_bin) { + auto pos = path_bin.rfind('.'); + if (pos != std::string::npos) { + path_bin = path_bin.substr(0, pos); + } + + // match "-qx_x" + pos = path_bin.rfind('-'); + if (pos != std::string::npos) { + auto sub = path_bin.substr(pos); + if (sub.size() == 5 && sub[1] == 'q' && sub[3] == '_') { + path_bin = path_bin.substr(0, pos); + } + } + + path_bin += "-encoder.mlmodelc"; + + return path_bin; +} +#endif + +#ifdef WHISPER_USE_OPENVINO +// replace .bin with-encoder-openvino.xml +static std::string whisper_openvino_get_path_encoder(std::string path_bin) { + auto pos = path_bin.rfind('.'); + if (pos != std::string::npos) { + path_bin = path_bin.substr(0, pos); + } + + path_bin += "-encoder-openvino.xml"; + + return path_bin; +} + +static std::string whisper_openvino_get_path_cache(std::string path_bin) { + auto pos = path_bin.rfind('.'); + if (pos != std::string::npos) { + path_bin = path_bin.substr(0, pos); + } + + path_bin += "-encoder-openvino-cache"; + + return path_bin; +} +#endif + +struct whisper_state * whisper_init_state(whisper_context * ctx) { + fill_sin_cos_table(); + whisper_state * state = new whisper_state; + + const size_t scale = ctx->model.hparams.ftype ? 1 : 2; + + if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) { + log("%s: kv_cache_init() failed for self-attention cache\n", __func__); + delete state; + return nullptr; + } + + { + const size_t memory_size = wsp_ggml_nbytes(state->decoders[0].kv_self.k) + wsp_ggml_nbytes(state->decoders[0].kv_self.v); + log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + + if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) { + log("%s: kv_cache_init() failed for cross-attention cache\n", __func__); + delete state; + return nullptr; + } + + { + const size_t memory_size = wsp_ggml_nbytes(state->kv_cross.k) + wsp_ggml_nbytes(state->kv_cross.v); + log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + +#ifdef WHISPER_USE_COREML + const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); + + log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); + log("%s: first run on a device may take a while ...\n", __func__); + + state->ctx_coreml = whisper_coreml_init(path_coreml.c_str()); + if (!state->ctx_coreml) { + log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); +#ifndef WHISPER_COREML_ALLOW_FALLBACK + return nullptr; +#endif + } else { + log("%s: Core ML model loaded\n", __func__); + } +#endif + + state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx); + + state->logits_id.reserve(ctx->model.hparams.n_vocab); + + // TAGS: WHISPER_DECODER_INIT + state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx); + + state->decoders[0].probs.reserve(ctx->vocab.n_vocab); + state->decoders[0].logits.reserve(ctx->vocab.n_vocab); + state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab); + state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type))); + + state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type)); + 
state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type)); + state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type)); + state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type)); + + state->rng = std::mt19937(0); + + return state; +} + +int whisper_ctx_init_openvino_encoder( + struct whisper_context * ctx, + const char * model_path, + const char * device, + const char * cache_dir) { +#ifndef WHISPER_USE_OPENVINO + (void)(ctx); + (void)(model_path); + (void)(device); + (void)(cache_dir); + + return 1; +#else + if (!model_path && ctx->path_model.empty()) { + log("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__); + return 1; + } + + std::string path_encoder; + if (!model_path) { + //if model_path is not set, attempt to find it in the same directory as ggml-.bin model + path_encoder = whisper_openvino_get_path_encoder(ctx->path_model); + } else { + path_encoder = model_path; + } + + std::string path_cache; + if (!cache_dir) { + //if cache_dir is not set, set it as a dir residing next to ggml-.bin + path_cache = whisper_openvino_get_path_cache(ctx->path_model); + } else { + path_cache = cache_dir; + } + + log("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str()); + log("%s: first run on a device may take a while ...\n", __func__); + + ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str()); + if (!ctx->state->ctx_openvino) { + log("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str()); + return 1; + } else { + log("%s: OpenVINO model loaded\n", __func__); + } + + return 0; +#endif +} + +struct whisper_context * whisper_init_from_file_no_state(const char * path_model) { + + log("%s: loading model from '%s'\n", __func__, path_model); + + auto fin = std::ifstream(path_model, std::ios::binary); + if (!fin) { + log("%s: failed to open '%s'\n", __func__, path_model); + return nullptr; + } + + whisper_model_loader loader = {}; + + loader.context = &fin; + + loader.read = [](void * ctx, void * output, size_t read_size) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->read((char *)output, read_size); + return read_size; + }; + + loader.eof = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + return fin->eof(); + }; + + loader.close = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->close(); + }; + + auto ctx = whisper_init_no_state(&loader); + + if (ctx) { + ctx->path_model = path_model; + } + + return ctx; +} + +struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) { + struct buf_context { + uint8_t* buffer; + size_t size; + size_t current_offset; + }; + + buf_context ctx = { reinterpret_cast(buffer), buffer_size, 0 }; + + log("%s: loading model from buffer\n", __func__); + + whisper_model_loader loader = {}; + + loader.context = &ctx; + + loader.read = [](void * ctx, void * output, size_t read_size) { + buf_context * buf = reinterpret_cast(ctx); + + size_t size_to_copy = buf->current_offset + read_size < buf->size ? 
read_size : buf->size - buf->current_offset; + + memcpy(output, buf->buffer + buf->current_offset, size_to_copy); + buf->current_offset += size_to_copy; + + return size_to_copy; + }; + + loader.eof = [](void * ctx) { + buf_context * buf = reinterpret_cast(ctx); + + return buf->current_offset >= buf->size; + }; + + loader.close = [](void * /*ctx*/) { }; + + return whisper_init_no_state(&loader); +} + +struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) { + wsp_ggml_time_init(); + + whisper_context * ctx = new whisper_context; + + if (!whisper_model_load(loader, *ctx)) { + loader->close(loader->context); + log("%s: failed to load model\n", __func__); + delete ctx; + return nullptr; + } + + loader->close(loader->context); + + return ctx; +} + +struct whisper_context * whisper_init_from_file(const char * path_model) { + whisper_context * ctx = whisper_init_from_file_no_state(path_model); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { + whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context * whisper_init(struct whisper_model_loader * loader) { + whisper_context * ctx = whisper_init_no_state(loader); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +void whisper_free_state(struct whisper_state * state) +{ + if (state) { + kv_cache_free(state->kv_cross); + + for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { + kv_cache_free(state->decoders[i].kv_self); + } + +#ifdef WHISPER_USE_COREML + if (state->ctx_coreml != nullptr) { + whisper_coreml_free(state->ctx_coreml); + state->ctx_coreml = nullptr; + } +#endif + +#ifdef WHISPER_USE_OPENVINO + if (state->ctx_openvino != nullptr) { + whisper_openvino_free(state->ctx_openvino); + state->ctx_openvino = nullptr; + } +#endif + + delete state; + } +} + +void whisper_free(struct whisper_context * ctx) { + if (ctx) { + if (ctx->model.ctx) { + wsp_ggml_free(ctx->model.ctx); + } + if (ctx->model.buf) { + delete ctx->model.buf; + } + + whisper_free_state(ctx->state); + + delete ctx; + } +} + +void whisper_free_params(struct whisper_full_params * params) { + if (params) { + delete params; + } +} + +int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) { + log("%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + +int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { + return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) +int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { + if 
(!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) { + log("%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) +int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { + return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2 +// TODO + +// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2 +// TODO + +// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2 +// TODO + +int whisper_set_mel_with_state( + struct whisper_context * /*ctx*/, + struct whisper_state * state, + const float * data, + int n_len, + int n_mel) { + if (n_mel != WHISPER_N_MEL) { + log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL); + return -1; + } + + state->mel.n_len = n_len; + state->mel.n_len_org = n_len; + state->mel.n_mel = n_mel; + + state->mel.data.resize(n_len*n_mel); + memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); + + return 0; +} + +int whisper_set_mel( + struct whisper_context * ctx, + const float * data, + int n_len, + int n_mel) { + return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel); +} + +int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) { + if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) { + log("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { + if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) { + log("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { + const int selected_decoder_id = 0; + + if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { + log("%s: failed to eval\n", __func__); + return 1; + } + + return 0; +} + +int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { + // TODO: add selected_decoder_id to state + const int selected_decoder_id = 0; + + if (ctx->state == nullptr) { + log("%s: ERROR state was not loaded.\n", __func__); + return false; + } + + if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { + log("%s: failed to eval\n", __func__); + return 1; + } + + return 0; +} + +int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) { + const auto res = tokenize(ctx->vocab, text); + + if (n_max_tokens < (int) res.size()) { + log("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); + return -1; + } + + for (int i = 0; i < (int) res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +int whisper_lang_max_id() { + auto max_id = 0; + for (const auto & kv : g_lang) { + max_id = std::max(max_id, 
kv.second.first); + } + + return max_id; +} + +int whisper_lang_id(const char * lang) { + if (!g_lang.count(lang)) { + for (const auto & kv : g_lang) { + if (kv.second.second == lang) { + return kv.second.first; + } + } + + log("%s: unknown language '%s'\n", __func__, lang); + return -1; + } + return g_lang.at(lang).first; +} + +const char * whisper_lang_str(int id) { + for (const auto & kv : g_lang) { + if (kv.second.first == id) { + return kv.first.c_str(); + } + } + + log("%s: unknown language id %d\n", __func__, id); + return nullptr; +} + +int whisper_lang_auto_detect_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + int offset_ms, + int n_threads, + float * lang_probs) { + const int seek = offset_ms/10; + + if (seek < 0) { + log("%s: offset %dms is before the start of the audio\n", __func__, offset_ms); + return -1; + } + + if (seek >= state->mel.n_len_org) { + log("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10); + return -2; + } + + // run the encoder + if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) { + log("%s: failed to encode\n", __func__); + return -6; + } + + const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) }; + + if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) { + log("%s: failed to decode\n", __func__); + return -7; + } + + auto & logits_id = state->logits_id; + logits_id.clear(); + + for (const auto & kv : g_lang) { + const auto token_lang = whisper_token_lang(ctx, kv.second.first); + logits_id.emplace_back(state->logits[token_lang], kv.second.first); + } + + // sort descending + { + using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type; + std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) { + return a.first > b.first; + }); + } + + // softmax + { + const auto max = logits_id[0].first; + + double sum = 0.0f; + for (auto & kv : logits_id) { + kv.first = exp(kv.first - max); + sum += kv.first; + } + + for (auto & kv : logits_id) { + kv.first /= sum; + } + } + + { + for (const auto & prob : logits_id) { + if (lang_probs) { + lang_probs[prob.second] = prob.first; + } + + //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first); + } + } + + return logits_id[0].second; +} + +int whisper_lang_auto_detect( + struct whisper_context * ctx, + int offset_ms, + int n_threads, + float * lang_probs) { + return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs); +} + +int whisper_model_n_vocab(struct whisper_context * ctx) { + return ctx->model.hparams.n_vocab; +} + +int whisper_model_n_audio_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int whisper_model_n_audio_state(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_state; +} + +int whisper_model_n_audio_head(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_head; +} + +int whisper_model_n_audio_layer(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_layer; +} + +int whisper_model_n_text_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_ctx; +} + +int whisper_model_n_text_state(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_state; +} + +int whisper_model_n_text_head(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_head; +} + +int whisper_model_n_text_layer(struct whisper_context * ctx) { + return 
ctx->model.hparams.n_text_layer; +} + +int whisper_model_n_mels(struct whisper_context * ctx) { + return ctx->model.hparams.n_mels; +} + +int whisper_model_ftype(struct whisper_context * ctx) { + return ctx->model.hparams.ftype; +} + +int whisper_model_type(struct whisper_context * ctx) { + return ctx->model.type; +} + +const char *whisper_model_type_readable(struct whisper_context * ctx) { + switch (ctx->model.type) { + case e_model::MODEL_TINY: + return "tiny"; + case e_model::MODEL_BASE: + return "base"; + case e_model::MODEL_SMALL: + return "small"; + case e_model::MODEL_MEDIUM: + return "medium"; + case e_model::MODEL_LARGE: + return "large"; + default: + return "unknown"; + } +} + +int whisper_n_len_from_state(struct whisper_state * state) { + return state->mel.n_len_org; +} + +int whisper_n_len(struct whisper_context * ctx) { + return ctx->state->mel.n_len_org; +} + +int whisper_n_vocab(struct whisper_context * ctx) { + return ctx->vocab.n_vocab; +} + +int whisper_n_text_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_ctx; +} + +int whisper_n_audio_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int whisper_is_multilingual(struct whisper_context * ctx) { + return ctx->vocab.is_multilingual() ? 1 : 0; +} + +float * whisper_get_logits(struct whisper_context * ctx) { + return ctx->state->logits.data(); +} + +float * whisper_get_logits_from_state(struct whisper_state * state) { + return state->logits.data(); +} + +const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { + return ctx->vocab.id_to_token.at(token).c_str(); +} + +whisper_token whisper_token_eot(struct whisper_context * ctx) { + return ctx->vocab.token_eot; +} + +whisper_token whisper_token_sot(struct whisper_context * ctx) { + return ctx->vocab.token_sot; +} + +whisper_token whisper_token_solm(struct whisper_context * ctx) { + return ctx->vocab.token_solm; +} + +whisper_token whisper_token_prev(struct whisper_context * ctx) { + return ctx->vocab.token_prev; +} + +whisper_token whisper_token_nosp(struct whisper_context * ctx) { + return ctx->vocab.token_nosp; +} + +whisper_token whisper_token_not(struct whisper_context * ctx) { + return ctx->vocab.token_not; +} + +whisper_token whisper_token_beg(struct whisper_context * ctx) { + return ctx->vocab.token_beg; +} + +whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) { + return whisper_token_sot(ctx) + 1 + lang_id; +} + +whisper_token whisper_token_translate(struct whisper_context * ctx) { + return ctx->vocab.token_translate; +} + +whisper_token whisper_token_transcribe(struct whisper_context * ctx) { + return ctx->vocab.token_transcribe; +} + +void whisper_print_timings(struct whisper_context * ctx) { + const int64_t t_end_us = wsp_ggml_time_us(); + + log("\n"); + log("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + if (ctx->state != nullptr) { + + const int32_t n_sample = std::max(1, ctx->state->n_sample); + const int32_t n_encode = std::max(1, ctx->state->n_encode); + const int32_t n_decode = std::max(1, ctx->state->n_decode); + + log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f 
* ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + } + log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); +} + +void whisper_reset_timings(struct whisper_context * ctx) { + if (ctx->state != nullptr) { + ctx->state->t_sample_us = 0; + ctx->state->t_encode_us = 0; + ctx->state->t_decode_us = 0; + } +} + +static int whisper_has_coreml(void) { +#ifdef WHISPER_USE_COREML + return 1; +#else + return 0; +#endif +} + +static int whisper_has_openvino(void) { +#ifdef WHISPER_USE_OPENVINO + return 1; +#else + return 0; +#endif +} + +const char * whisper_print_system_info(void) { + static std::string s; + + s = ""; + s += "AVX = " + std::to_string(wsp_ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(wsp_ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(wsp_ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(wsp_ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(wsp_ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(wsp_ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(wsp_ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(wsp_ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(wsp_ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(wsp_ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(wsp_ggml_cpu_has_sse3()) + " | "; + s += "SSSE3 = " + std::to_string(wsp_ggml_cpu_has_ssse3()) + " | "; + s += "VSX = " + std::to_string(wsp_ggml_cpu_has_vsx()) + " | "; + s += "COREML = " + std::to_string(whisper_has_coreml()) + " | "; + s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | "; + + return s.c_str(); +} + +//////////////////////////////////////////////////////////////////////////// + +struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) { + struct whisper_full_params params = whisper_full_default_params(strategy); + + struct whisper_full_params* result = new whisper_full_params(); + *result = params; + return result; +} + +struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { + struct whisper_full_params result = { + /*.strategy =*/ strategy, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + + /*.translate =*/ false, + /*.no_context =*/ true, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, + + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.split_on_word =*/ false, + /*.max_tokens =*/ 0, + + /*.speed_up =*/ false, + /*.debug_mode =*/ false, + /*.audio_ctx =*/ 0, + + /*.tdrz_enable =*/ false, + + /*.initial_prompt =*/ nullptr, + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + + /*.language =*/ "en", + /*.detect_language =*/ false, + + /*.suppress_blank =*/ true, + /*.suppress_non_speech_tokens =*/ false, + + /*.temperature =*/ 0.0f, + /*.max_initial_ts =*/ 1.0f, + /*.length_penalty =*/ -1.0f, + + /*.temperature_inc =*/ 0.4f, + /*.entropy_thold =*/ 2.4f, + /*.logprob_thold =*/ -1.0f, + /*.no_speech_thold =*/ 0.6f, + + /*.greedy =*/ { + 
/*.best_of =*/ -1, + }, + + /*.beam_search =*/ { + /*.beam_size =*/ -1, + + /*.patience =*/ -1.0f, + }, + + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + + /*.logits_filter_callback =*/ nullptr, + /*.logits_filter_callback_user_data =*/ nullptr, + }; + + switch (strategy) { + case WHISPER_SAMPLING_GREEDY: + { + result.greedy = { + /*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding + }; + } break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + result.beam_search = { + /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding + + /*.patience =*/ -1.0f, + }; + } break; + } + + return result; +} + +// forward declarations +static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context & ctx, + struct whisper_state & state, + int i_segment, + float thold_pt, + float thold_ptsum); + +static inline bool should_split_on_word(const char * txt, bool split_on_word) { + if (!split_on_word) return true; + + return txt[0] == ' '; +} + +// wrap the last segment to max_len characters +// returns the number of new segments +static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) { + auto segment = state.result_all.back(); + + int res = 1; + int acc = 0; + + std::string text; + + for (int i = 0; i < (int) segment.tokens.size(); i++) { + const auto & token = segment.tokens[i]; + if (token.id >= whisper_token_eot(&ctx)) { + continue; + } + + const auto txt = whisper_token_to_str(&ctx, token.id); + const int cur = strlen(txt); + + if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) { + state.result_all.back().text = std::move(text); + state.result_all.back().t1 = token.t0; + state.result_all.back().tokens.resize(i); + state.result_all.back().speaker_turn_next = false; + + state.result_all.push_back({}); + state.result_all.back().t0 = token.t0; + state.result_all.back().t1 = segment.t1; + + // add tokens [i, end] to the new segment + state.result_all.back().tokens.insert( + state.result_all.back().tokens.end(), + segment.tokens.begin() + i, + segment.tokens.end()); + + state.result_all.back().speaker_turn_next = segment.speaker_turn_next; + + acc = 0; + text = ""; + + segment = state.result_all.back(); + i = -1; + + res++; + } else { + acc += cur; + text += txt; + } + } + + state.result_all.back().text = std::move(text); + + return res; +} + +static const std::vector<std::string> non_speech_tokens = { + "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", + "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", + "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", + "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" +}; + +// process the logits for the selected decoder +// - applies logit filters +// - computes logprobs and probs +static void whisper_process_logits( + struct whisper_context & ctx, + struct whisper_state & state, + const struct whisper_full_params params, + struct whisper_decoder & decoder, + float temperature) { + const auto & vocab = ctx.vocab; + const auto & tokens_cur = decoder.sequence.tokens; + + const bool is_initial = tokens_cur.size() == 0; + const 
int n_logits = vocab.id_to_token.size(); + + WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab); + + // extract the logits for the last token + // we will be mutating, and therefore we don't want to use the ctx.logits buffer directly + auto & probs = decoder.probs; + auto & logits = decoder.logits; + auto & logprobs = decoder.logprobs; + { + logits.resize(n_logits); + memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float)); + + if (temperature > 0.0f) { + for (int i = 0; i < n_logits; i++) { + logits[i] /= temperature; + } + } + + // will be populated a bit later + probs.resize(n_logits); + logprobs.resize(n_logits); + } + + // apply logit filters here + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493 + { + // suppress blank + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390 + if (params.suppress_blank) { + if (is_initial) { + logits[vocab.token_eot] = -INFINITY; + logits[vocab.token_to_id.at(" ")] = -INFINITY; + } + } + + // suppress <|notimestamps|> token + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 + logits[vocab.token_not] = -INFINITY; + + // suppress sot and nosp tokens + logits[vocab.token_sot] = -INFINITY; + logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now + + // [TDRZ] when tinydiarize is disabled, suppress solm token + if (params.tdrz_enable == false) { + logits[vocab.token_solm] = -INFINITY; + } + + // suppress task tokens + logits[vocab.token_translate] = -INFINITY; + logits[vocab.token_transcribe] = -INFINITY; + + if (params.logits_filter_callback) { + params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); + } + + // suppress non-speech tokens + // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + if (params.suppress_non_speech_tokens) { + for (const std::string & token : non_speech_tokens) { + const std::string suppress_tokens[] = {token, " " + token}; + for (const std::string & suppress_token : suppress_tokens) { + if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) { + logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; + } + } + } + + // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) { + logits[vocab.token_to_id.at(" -")] = -INFINITY; + } + if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) { + logits[vocab.token_to_id.at(" '")] = -INFINITY; + } + } + + // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 + { + const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; + const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; + + //log("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); + + if (last_was_timestamp) { + if (penultimate_was_timestamp) { + for (int i = vocab.token_beg; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } else { + for (int i = 0; i < vocab.token_eot; ++i) { + logits[i] = -INFINITY; + } 
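+ // at this point the last token was a lone (unpaired) timestamp, so only another timestamp or EOT may follow; all regular text tokens are masked above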
+ } + } + } + + // the initial timestamp cannot be larger than max_initial_ts + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + if (is_initial && params.max_initial_ts > 0.0f) { + const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx; + const int tid0 = std::round(params.max_initial_ts/precision); + + for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } + + // condition timestamp tokens to be increasing + // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556 + if (decoder.has_ts) { + const int tid0 = decoder.seek_delta/2; + + for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) { + logits[i] = -INFINITY; + } + } + + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } + } + logsumexp = logf(logsumexp) + logit_max; + + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } + } + } + + // if sum of probability over timestamps is above any other token, sample timestamp + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437 + { + // logsumexp over timestamps + float timestamp_logprob = -INFINITY; + { + float logsumexp = 0.0f; + const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end()); + for (int i = vocab.token_beg; i < n_logits; ++i) { + if (logprobs[i] > -INFINITY) { + logsumexp += expf(logprobs[i] - logprob_max); + } + } + if (logsumexp > 0.0f) { + timestamp_logprob = logf(logsumexp) + logprob_max; + } + } + + const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); + + //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); + + if (timestamp_logprob > max_text_token_logprob) { + for (int i = 0; i < vocab.token_beg; ++i) { + logits[i] = -INFINITY; + logprobs[i] = -INFINITY; + } + } + } + } + + // compute probs + { + for (int i = 0; i < n_logits; ++i) { + if (logits[i] == -INFINITY) { + probs[i] = 0.0f; + } else { + probs[i] = expf(logprobs[i]); + } + } + } + +#if 0 + // print first 100 logits - token string : logit + for (int i = 0; i < 100; i++) { + const auto token = vocab.id_to_token.at(i); + const auto prob = probs[i]; + const auto logit = logits[i]; + const auto logprob = logprobs[i]; + printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + } + + // "And", "and", " And", " and" + printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); + printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); + printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); + printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); + printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); + + printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); + printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); + printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); + printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); + printf("logprobs[\" so\"] = 
%f\n", logprobs[vocab.token_to_id.at(" so")]); + + printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); + printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); + printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); + printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); + printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); +#endif +} + +static whisper_token_data whisper_sample_token( + whisper_context & ctx, + whisper_state & state, + const whisper_decoder & decoder, + bool best) { + whisper_token_data result = { + 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, + }; + + const auto & vocab = ctx.vocab; + + const auto & probs = decoder.probs; + const auto & logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) { + if (probs[i] == -INFINITY) { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) { + max_ts = probs[i]; + result.tid = i; + } + } + + result.pt = max_ts/(sum_ts + 1e-10); + result.ptsum = sum_ts; + } + + if (best) { + for (int i = 0; i < n_logits; ++i) { + if (result.p < probs[i]) { + result.id = i; + result.p = probs[i]; + result.plog = logprobs[i]; + } + } + } else { + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + result.id = dist(state.rng); + result.p = probs[result.id]; + result.plog = logprobs[result.id]; + } + + if (result.id >= vocab.token_beg) { + result.tid = result.id; + result.pt = result.p; + } + + state.n_sample++; + + return result; +} + +static std::vector<whisper_token_data> whisper_sample_token_topk( + whisper_context & ctx, + whisper_state & state, + const whisper_decoder & decoder, + int k) { + const auto & vocab = ctx.vocab; + + const auto & probs = decoder.probs; + const auto & logits = decoder.logits; + const auto & logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + auto & logits_id = state.logits_id; + + logits_id.clear(); + for (int i = 0; i < n_logits; ++i) { + logits_id.push_back({ logits[i], i }); + } + + std::partial_sort( + logits_id.begin(), + logits_id.begin() + k, logits_id.end(), + [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) { + return a.first > b.first; + }); + + std::vector<whisper_token_data> result; + result.reserve(k); + + whisper_token tid = vocab.token_beg; + + float pt = 0.0; + float ptsum = 0.0; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) { + if (probs[i] == -INFINITY) { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) { + max_ts = probs[i]; + tid = i; + } + } + + pt = max_ts/(sum_ts + 1e-10); + ptsum = sum_ts; + } + + for (int i = 0; i < k; ++i) { + const auto id = logits_id[i].second; + + result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, }); + + if (result[i].id >= vocab.token_beg) { + result[i].tid = result[i].id; + result[i].pt = result[i].p; + } + } + + state.n_sample++; + + return result; +} + +// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192 +static void whisper_sequence_score( + const struct whisper_full_params & params, + whisper_sequence & sequence) { + if (sequence.result_len == 0) { + return; + } + + double result = 0.0f; + + for (int i = 0; i < sequence.result_len; ++i) { + result += sequence.tokens[i].plog; + } + + sequence.sum_logprobs = result; + sequence.avg_logprobs = result/sequence.result_len; + + double penalty = sequence.result_len; + 
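+ // score = sum_logprobs / penalty: by default the penalty is the raw sequence length, while a positive length_penalty uses the Google NMT style term ((5 + length)/6)^length_penalty, matching the OpenAI reference linked above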
+ if (params.length_penalty > 0.0f) { + penalty = pow((5.0 + penalty)/6.0, params.length_penalty); + } + + sequence.score = result/penalty; + + // compute the entropy of the sequence of the last 32 tokens + { + const int n = 32; + + int cnt = 0; + double entropy = 0.0f; + + std::map<whisper_token, int> token_counts; + for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) { + token_counts[sequence.tokens[i].id]++; + cnt++; + } + + for (const auto & kv : token_counts) { + const auto p = kv.second/(double)cnt; + entropy -= p*log(p); + + //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second); + } + + sequence.entropy = entropy; + } +} + +int whisper_full_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + struct whisper_full_params params, + const float * samples, + int n_samples) { + // clear old results + auto & result_all = state->result_all; + + result_all.clear(); + + if (n_samples > 0) { + // compute log mel spectrogram + if (params.speed_up) { + // TODO: Replace PV with more advanced algorithm + log("%s: failed to compute log mel spectrogram\n", __func__); + return -1; + } else { + if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + log("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + } + + // auto-detect language if not specified + if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) { + std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f); + + const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data()); + if (lang_id < 0) { + log("%s: failed to auto-detect language\n", __func__); + return -3; + } + state->lang_id = lang_id; + params.language = whisper_lang_str(lang_id); + + log("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); + if (params.detect_language) { + return 0; + } + } + + if (params.token_timestamps) { + state->t_beg = 0; + state->t_last = 0; + state->tid_last = 0; + if (n_samples > 0) { + state->energy = get_signal_energy(samples, n_samples, 32); + } + } + + const int seek_start = params.offset_ms/10; + const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10; + + // if length of spectrogram is less than 1.0s (100 frames), then return + // basically don't process anything that is less than 1.0s + // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 + if (seek_end < seek_start + (params.speed_up ? 
50 : 100)) { + return 0; + } + + // a set of temperatures to use + // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ] + std::vector<float> temperatures; + if (params.temperature_inc > 0.0f) { + for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) { + temperatures.push_back(t); + } + } else { + temperatures.push_back(params.temperature); + } + + // initialize the decoders + int n_decoders = 1; + + switch (params.strategy) { + case WHISPER_SAMPLING_GREEDY: + { + n_decoders = params.greedy.best_of; + } break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size); + } break; + }; + + n_decoders = std::max(1, n_decoders); + + // TAGS: WHISPER_DECODER_INIT + for (int j = 1; j < n_decoders; j++) { + auto & decoder = state->decoders[j]; + + if (decoder.kv_self.ctx == nullptr) { + decoder.kv_self = state->decoders[0].kv_self; + if (!kv_cache_reinit(decoder.kv_self)) { + log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); + return -4; + } + + WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j); + + decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity()); + + decoder.probs.resize (ctx->vocab.n_vocab); + decoder.logits.resize (ctx->vocab.n_vocab); + decoder.logprobs.resize(ctx->vocab.n_vocab); + } + } + + // the accumulated text context so far + auto & prompt_past = state->prompt_past; + if (params.no_context) { + prompt_past.clear(); + } + + // prepare prompt + { + std::vector<whisper_token> prompt_tokens; + + // initial prompt + if (!params.prompt_tokens && params.initial_prompt) { + prompt_tokens.resize(1024); + prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size())); + params.prompt_tokens = prompt_tokens.data(); + params.prompt_n_tokens = prompt_tokens.size(); + } + + // prepend the prompt tokens to the prompt_past + if (params.prompt_tokens && params.prompt_n_tokens > 0) { + // parse tokens from the pointer + for (int i = 0; i < params.prompt_n_tokens; i++) { + prompt_past.push_back(params.prompt_tokens[i]); + } + std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); + } + } + + // overwrite audio_ctx, max allowed is hparams.n_audio_ctx + if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { + log("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); + return -5; + } + state->exp_n_audio_ctx = params.audio_ctx; + + // these tokens determine the task that will be performed + std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) }; + if (whisper_is_multilingual(ctx)) { + const int lang_id = whisper_lang_id(params.language); + state->lang_id = lang_id; + prompt_init.push_back(whisper_token_lang(ctx, lang_id)); + if (params.translate) { + prompt_init.push_back(whisper_token_translate(ctx)); + } else { + prompt_init.push_back(whisper_token_transcribe(ctx)); + } + } + + int seek = seek_start; + + std::vector<whisper_token> prompt; + prompt.reserve(whisper_n_text_ctx(ctx)); + + // beam-search helpers + struct kv_buf { + std::vector<uint8_t> k; + std::vector<uint8_t> v; + }; + + std::vector<kv_buf> kv_bufs; + + struct beam_candidate { + int decoder_idx; + int seek_delta; + + bool has_ts; + + whisper_sequence sequence; + }; + + std::vector<beam_candidate> beam_candidates; + + // main loop + while (true) { + if (params.progress_callback) { + const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); + + 
params.progress_callback( + ctx, ctx->state, progress_cur, params.progress_callback_user_data); + } + + // of only 1 second left, then stop + if (seek + 100 >= seek_end) { + break; + } + + if (params.encoder_begin_callback) { + if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) { + log("%s: encoder_begin_callback returned false - aborting\n", __func__); + break; + } + } + + // encode audio features starting at offset seek + if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) { + log("%s: failed to encode\n", __func__); + return -6; + } + + // if there is a very short audio segment left to process, we remove any past prompt since it tends + // to confuse the decoder and often make it repeat or hallucinate stuff + if (seek > seek_start && seek + 500 >= seek_end) { + prompt_past.clear(); + } + + int best_decoder_id = 0; + + for (int it = 0; it < (int) temperatures.size(); ++it) { + const float t_cur = temperatures[it]; + + int n_decoders_cur = 1; + + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur > 0.0f) { + n_decoders_cur = params.greedy.best_of; + } + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + if (t_cur > 0.0f) { + n_decoders_cur = params.greedy.best_of; + } else { + n_decoders_cur = params.beam_search.beam_size; + } + } break; + }; + + n_decoders_cur = std::max(1, n_decoders_cur); + + WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); + + // TAGS: WHISPER_DECODER_INIT + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + decoder.kv_self.n = 0; + + decoder.sequence.tokens.clear(); + decoder.sequence.result_len = 0; + decoder.sequence.sum_logprobs_all = 0.0; + decoder.sequence.sum_logprobs = -INFINITY; + decoder.sequence.avg_logprobs = -INFINITY; + decoder.sequence.entropy = 0.0; + decoder.sequence.score = -INFINITY; + + decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; + + decoder.failed = false; + decoder.completed = false; + decoder.has_ts = false; + } + + // init prompt and kv cache for the current iteration + // run whisper_decoder() only for decoder 0 and copy the results for the other decoders + { + prompt.clear(); + + // if we have already generated some text, use it as a prompt to condition the next generation + if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) { + int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); + + prompt = { whisper_token_prev(ctx) }; + prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); + } + + // init new transcription with sot, language (opt) and task tokens + prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); + + // print the prompt + WHISPER_PRINT_DEBUG("\n\n"); + for (int i = 0; i < (int) prompt.size(); i++) { + WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); + } + WHISPER_PRINT_DEBUG("\n\n"); + + if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { + log("%s: failed to decode\n", __func__); + return -7; + } + + { + const int64_t t_start_sample_us = wsp_ggml_time_us(); + + whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur); + + state->decoders[0].kv_self.n += prompt.size(); + + for (int j = 1; j < n_decoders_cur; ++j) { + auto & decoder = 
state->decoders[j]; + + memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, wsp_ggml_nbytes(decoder.kv_self.k)); + memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, wsp_ggml_nbytes(decoder.kv_self.v)); + + decoder.kv_self.n += prompt.size(); + + memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); + memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); + memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); + } + + state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us; + } + } + + for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { + const int64_t t_start_sample_us = wsp_ggml_time_us(); + + // store the KV caches of all decoders when doing beam-search + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + kv_bufs.resize(n_decoders_cur); + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + kv_bufs[j].k.resize(wsp_ggml_nbytes(decoder.kv_self.k)); + kv_bufs[j].v.resize(wsp_ggml_nbytes(decoder.kv_self.v)); + + memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size()); + memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size()); + } + + beam_candidates.clear(); + } + + // generate new sequence candidates for each decoder + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur < 1e-6f) { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true)); + } else { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false)); + } + + decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size); + + for (const auto & token : tokens_new) { + beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence }); + beam_candidates.back().sequence.tokens.push_back(token); + beam_candidates.back().sequence.sum_logprobs_all += token.plog; + + //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all); + } + } break; + }; + } + + // for beam-search, choose the top candidates and update the KV caches + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + std::sort( + beam_candidates.begin(), + beam_candidates.end(), + [](const beam_candidate & a, const beam_candidate & b) { + return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; + }); + + uint32_t cur_c = 0; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + auto & cur = beam_candidates[cur_c++]; + + while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { + ++cur_c; + } + + decoder.sequence = cur.sequence; + decoder.seek_delta = cur.seek_delta; + decoder.has_ts = 
cur.has_ts; + + memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size()); + memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size()); + + WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", + __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); + } + } + + // update the decoder state + // - check if the sequence is completed + // - check if the sequence is failed + // - update sliding window based on timestamp tokens + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + auto & has_ts = decoder.has_ts; + auto & failed = decoder.failed; + auto & completed = decoder.completed; + auto & seek_delta = decoder.seek_delta; + auto & result_len = decoder.sequence.result_len; + + { + const auto & token = decoder.sequence.tokens.back(); + + // timestamp token - update sliding window + if (token.id > whisper_token_beg(ctx)) { + const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + + // do not allow to go back in time + if (has_ts && seek_delta > seek_delta_new && result_len < i) { + failed = true; // TODO: maybe this is not a failure ? + continue; + } + + seek_delta = seek_delta_new; + result_len = i + 1; + has_ts = true; + } + +#ifdef WHISPER_DEBUG + { + const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]"; + WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", + __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); + } +#endif + + // end of segment + if (token.id == whisper_token_eot(ctx) || // end of text token + (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached + (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached + ) { + if (result_len == 0) { + if (seek + seek_delta + 100 >= seek_end) { + result_len = i + 1; + } else { + failed = true; + continue; + } + } + + if (params.single_segment) { + result_len = i + 1; + seek_delta = 100*WHISPER_CHUNK_SIZE; + } + + completed = true; + continue; + } + + // TESTS: if no tensors are loaded, it means we are running tests + if (ctx->model.n_loaded == 0) { + seek_delta = 100*WHISPER_CHUNK_SIZE; + completed = true; + continue; + } + } + + // sometimes, the decoding can get stuck in a repetition loop + // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + failed = true; + continue; + } + } + + // check if all decoders have finished (i.e. 
completed or failed) + { + bool completed_all = true; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + completed_all = false; + } + + if (completed_all) { + break; + } + } + + state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us; + + // obtain logits for the next token + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.failed || decoder.completed) { + continue; + } + + decoder.tokens_tmp.resize(1); + decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; + + //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); + + if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { + log("%s: failed to decode\n", __func__); + return -8; + } + + { + const int64_t t_start_sample_us = wsp_ggml_time_us(); + + whisper_process_logits(*ctx, *state, params, decoder, t_cur); + + ++decoder.kv_self.n; + + state->t_sample_us += wsp_ggml_time_us() - t_start_sample_us; + } + } + } + + // rank the resulting sequences and select the best one + { + double best_score = -INFINITY; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.failed) { + continue; + } + + decoder.sequence.tokens.resize(decoder.sequence.result_len); + whisper_sequence_score(params, decoder.sequence); + + WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", + __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); + + if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) { + WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", + __func__, j, decoder.sequence.entropy, params.entropy_thold); + + decoder.failed = true; + state->n_fail_h++; + + continue; + } + + if (best_score < decoder.sequence.score) { + best_score = decoder.sequence.score; + best_decoder_id = j; + } + } + + WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id); + } + + // was the decoding successful for the current temperature? 
+ // do fallback only if: + // - we are not at the last temperature + // - we are not at the end of the audio (3 sec) + if (it != (int) temperatures.size() - 1 && + seek_end - seek > 10*WHISPER_CHUNK_SIZE) { + bool success = true; + + const auto & decoder = state->decoders[best_decoder_id]; + + if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { + success = false; + state->n_fail_p++; + } + + if (success) { + //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { + // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); + //} + + break; + } + } + + WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); + } + + // output results through a user-provided callback + { + const auto & best_decoder = state->decoders[best_decoder_id]; + + const auto seek_delta = best_decoder.seek_delta; + const auto result_len = best_decoder.sequence.result_len; + + const auto & tokens_cur = best_decoder.sequence.tokens; + + //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); + + // update prompt_past + prompt_past.clear(); + if (prompt.front() == whisper_token_prev(ctx)) { + prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); + } + + for (int i = 0; i < result_len; ++i) { + prompt_past.push_back(tokens_cur[i].id); + } + + if (!tokens_cur.empty() && ctx->model.n_loaded > 0) { + int i0 = 0; + auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); + + std::string text; + bool speaker_turn_next = false; + + for (int i = 0; i < (int) tokens_cur.size(); i++) { + //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + + if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) { + text += whisper_token_to_str(ctx, tokens_cur[i].id); + } + + // [TDRZ] record if speaker turn was predicted after current segment + if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) { + speaker_turn_next = true; + } + + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { + const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); + + if (!text.empty()) { + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 
2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } + } + + //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); + + result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next }); + for (int j = i0; j <= i; j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); + } + } + text = ""; + while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { + i++; + } + i--; + t0 = t1; + i0 = i + 1; + speaker_turn_next = false; + } + } + + if (!text.empty()) { + const auto t1 = seek + seek_delta; + + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } + } + + result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next }); + for (int j = i0; j < (int) tokens_cur.size(); j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); + } + } + } + + // update audio window + seek += seek_delta; + + WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta); + } + } + + return 0; +} + +int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples) { + return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples); +} + +int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors) { + if (n_processors == 1) { + return whisper_full(ctx, params, samples, n_samples); + } + int ret = 0; + + // prepare separate states for each thread + std::vector<whisper_state *> states; + + const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; + const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; + + // the calling thread will process the first chunk + // while the other threads will process the remaining chunks + + std::vector<std::thread> workers(n_processors - 1); + for (int i = 0; i < n_processors - 1; ++i) { + // create a new state for each thread + states.push_back(whisper_init_state(ctx)); + + const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; + const int n_samples_cur = (i == n_processors - 2) ? 
n_samples - start_samples : n_samples_per_processor; + + auto params_cur = params; + + params_cur.offset_ms = 0; + params_cur.print_progress = false; + params_cur.print_realtime = false; + + params_cur.new_segment_callback = nullptr; + params_cur.new_segment_callback_user_data = nullptr; + + params_cur.progress_callback = nullptr; + params_cur.progress_callback_user_data = nullptr; + + workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur); + } + + { + auto params_cur = params; + + // We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk. + params_cur.print_realtime = false; + + // Run the first transformation using default state but only for the first chunk. + ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor); + } + + for (int i = 0; i < n_processors - 1; ++i) { + workers[i].join(); + } + + const int64_t offset_t = (int64_t) params.offset_ms/10.0; + + // combine results into result_state->result_all from all other states + for (int i = 0; i < n_processors - 1; ++i) { + auto& results_i = states[i]->result_all; + + for (auto& result : results_i) { + // correct the segment timestamp taking into account the offset + result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; + result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; + + // make sure that segments are not overlapping + if (!ctx->state->result_all.empty()) { + result.t0 = std::max(result.t0, ctx->state->result_all.back().t1); + } + + ctx->state->result_all.push_back(std::move(result)); + + // call the new_segment_callback for each segment + if (params.new_segment_callback) { + params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data); + } + } + + ctx->state->t_mel_us += states[i]->t_mel_us; + + ctx->state->t_sample_us += states[i]->t_sample_us; + ctx->state->t_encode_us += states[i]->t_encode_us; + ctx->state->t_decode_us += states[i]->t_decode_us; + + whisper_free_state(states[i]); + } + + // average the timings + ctx->state->t_mel_us /= n_processors; + ctx->state->t_sample_us /= n_processors; + ctx->state->t_encode_us /= n_processors; + ctx->state->t_decode_us /= n_processors; + + // print information about the audio boundaries + log("\n"); + log("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors); + for (int i = 0; i < n_processors - 1; ++i) { + log("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str()); + } + log("%s: the transcription quality may be degraded near these boundaries\n", __func__); + + return ret; +} + +int whisper_full_n_segments_from_state(struct whisper_state * state) { + return state->result_all.size(); +} + +int whisper_full_n_segments(struct whisper_context * ctx) { + return ctx->state->result_all.size(); +} + +int whisper_full_lang_id_from_state(struct whisper_state * state) { + return state->lang_id; +} + +int whisper_full_lang_id(struct whisper_context * ctx) { + return ctx->state->lang_id; +} + +int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].t0; +} + +int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].t0; +} + +int64_t 
whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].t1; +} + +int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].t1; +} + +bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].speaker_turn_next; +} + +const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].text.c_str(); +} + +const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].text.c_str(); +} + +int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].tokens.size(); +} + +int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].tokens.size(); +} + +const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) { + return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].id; +} + +whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].id; +} + +struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token]; +} + +struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token]; +} + +float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].p; +} + +float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].p; +} + +// ================================================================================================= + +// +// Temporary interface needed for exposing ggml interface +// Will be removed in the future when ggml becomes a separate library +// + +WHISPER_API int whisper_bench_memcpy(int n_threads) { + fputs(whisper_bench_memcpy_str(n_threads), stderr); + return 0; +} + +WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) { + static std::string s; + s = ""; + char strbuf[256]; + + wsp_ggml_time_init(); + + size_t n = 20; + size_t arr = n_threads > 0 ? 
1024llu : n_threads; // trick to avoid compiler optimizations + + // 1GB array + const size_t size = arr*1024llu*1024llu; + + // single-thread + { + char * src = (char *) malloc(size); + char * dst = (char *) malloc(size); + + for (size_t i = 0; i < size; i++) src[i] = i; + + memcpy(dst, src, size); // heat-up + + double tsum = 0.0; + double sum = 0.0; + + for (size_t i = 0; i < n; i++) { + const int64_t t0 = wsp_ggml_time_us(); + + memcpy(dst, src, size); + + const int64_t t1 = wsp_ggml_time_us(); + + tsum += (t1 - t0)*1e-6; + + src[rand() % size] = rand() % 256; + } + + snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu)); + s += strbuf; + + // needed to prevent the compiler from optimizing the memcpy away + { + for (size_t i = 0; i < size; i++) sum += dst[i]; + + snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum); + s += strbuf; + } + + free(src); + free(dst); + } + + return s.c_str(); +} + +WHISPER_API int whisper_bench_wsp_ggml_mul_mat(int n_threads) { + fputs(whisper_bench_wsp_ggml_mul_mat_str(n_threads), stderr); + return 0; +} + +WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) { + static std::string s; + s = ""; + char strbuf[256]; + + wsp_ggml_time_init(); + + const int n_max = 128; + + const std::vector<size_t> sizes = { + 64, 128, 256, 512, 1024, 2048, 4096, + }; + + const size_t N_max = sizes.back(); + + // a: N*N*sizeof(float) + // b: N*N*sizeof(float) + // c: N*N*sizeof(float) + // when F16 is used, there is an extra work buffer of size N*N*sizeof(float) + std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*512); + + // put a bunch of random data in the buffer + for (size_t i = 0; i < buf.size(); i++) buf[i] = i; + + for (int j = 0; j < (int) sizes.size(); j++) { + int n_q4_0 = 0; + int n_q4_1 = 0; + int n_q5_0 = 0; + int n_q5_1 = 0; + int n_q8_0 = 0; + int n_fp16 = 0; + int n_fp32 = 0; + + // GFLOPS/s + double s_q4_0 = 0.0; + double s_q4_1 = 0.0; + double s_q5_0 = 0.0; + double s_q5_1 = 0.0; + double s_q8_0 = 0.0; + double s_fp16 = 0.0; + double s_fp32 = 0.0; + + const size_t N = sizes[j]; + + for (int k = 0; k < 7; ++k) { + const wsp_ggml_type wtype = + k == 0 ? WSP_GGML_TYPE_Q4_0 : + k == 1 ? WSP_GGML_TYPE_Q4_1 : + k == 2 ? WSP_GGML_TYPE_Q5_0 : + k == 3 ? WSP_GGML_TYPE_Q5_1 : + k == 4 ? WSP_GGML_TYPE_Q8_0 : + k == 5 ? WSP_GGML_TYPE_F16 : WSP_GGML_TYPE_F32; + + double & s = k == 0 ? s_q4_0 : k == 1 ? s_q4_1 : k == 2 ? s_q5_0 : k == 3 ? s_q5_1 : k == 4 ? s_q8_0 : k == 5 ? s_fp16 : /*k == 6*/ s_fp32; + int & n = k == 0 ? n_q4_0 : k == 1 ? n_q4_1 : k == 2 ? n_q5_0 : k == 3 ? n_q5_1 : k == 4 ? n_q8_0 : k == 5 ? 
n_fp16 : /*k == 6*/ n_fp32; + + struct wsp_ggml_init_params gparams = { + /*.mem_size =*/ buf.size(), + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ false, + }; + + struct wsp_ggml_context * ctx0 = wsp_ggml_init(gparams); + + struct wsp_ggml_tensor * a = wsp_ggml_new_tensor_2d(ctx0, wtype, N, N); + struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, N, N); + + struct wsp_ggml_tensor * c = wsp_ggml_mul_mat(ctx0, a, b); + + struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(c); + + gf.n_threads = n_threads; + + double tsum = 0.0; + + // heat-up + wsp_ggml_graph_compute(ctx0, &gf); + + for (int i = 0; i < n_max; ++i) { + const int64_t t0 = wsp_ggml_time_us(); + + wsp_ggml_graph_compute(ctx0, &gf); + + const int64_t t1 = wsp_ggml_time_us(); + + tsum += (t1 - t0)*1e-6; + n++; + + if (tsum > 1.0 && n >= 3) { + break; + } + } + + wsp_ggml_free(ctx0); + + s = ((2.0*N*N*N*n)/tsum)*1e-9; + } + + // Q4_0 | Q4_1 + snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q4_0 %7.1f GFLOPS (%3d runs) | Q4_1 %7.1f GFLOPS (%3d runs)\n", + N, N, s_q4_0, n_q4_0, s_q4_1, n_q4_1); + s += strbuf; + + // Q5_0 | Q5_1 | Q8_0 + snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q5_0 %7.1f GFLOPS (%3d runs) | Q5_1 %7.1f GFLOPS (%3d runs) | Q8_0 %7.1f GFLOPS (%3d runs)\n", + N, N, s_q5_0, n_q5_0, s_q5_1, n_q5_1, s_q8_0, n_q8_0); + s += strbuf; + + // F16 | F32 + snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: F16 %7.1f GFLOPS (%3d runs) | F32 %7.1f GFLOPS (%3d runs)\n", + N, N, s_fp16, n_fp16, s_fp32, n_fp32); + s += strbuf; + } + + return s.c_str(); +} + +// ================================================================================================= + +// ================================================================================================= + +// +// Experimental stuff below +// +// Not sure if these should be part of the library at all, because the quality of the results is not +// guaranteed. 
Might get removed at some point unless a robust algorithm implementation is found
+//
+
+// =================================================================================================
+
+//
+// token-level timestamps
+//
+
+static int timestamp_to_sample(int64_t t, int n_samples) {
+    return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
+}
+
+static int64_t sample_to_timestamp(int i_sample) {
+    return (100ll*i_sample)/WHISPER_SAMPLE_RATE;
+}
+
+// a cost-function / heuristic that is high for text that takes longer to pronounce
+// obviously, can be improved
+static float voice_length(const std::string & text) {
+    float res = 0.0f;
+
+    for (char c : text) {
+        if (c == ' ') {
+            res += 0.01f;
+        } else if (c == ',') {
+            res += 2.00f;
+        } else if (c == '.') {
+            res += 3.00f;
+        } else if (c == '!') {
+            res += 3.00f;
+        } else if (c == '?') {
+            res += 3.00f;
+        } else if (c >= '0' && c <= '9') {
+            res += 3.00f;
+        } else {
+            res += 1.00f;
+        }
+    }
+
+    return res;
+}
+
+// average the fabs of the signal
+static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
+    const int hw = n_samples_per_half_window;
+
+    std::vector<float> result(n_samples);
+
+    for (int i = 0; i < n_samples; i++) {
+        float sum = 0;
+        for (int j = -hw; j <= hw; j++) {
+            if (i + j >= 0 && i + j < n_samples) {
+                sum += fabs(signal[i + j]);
+            }
+        }
+        result[i] = sum/(2*hw + 1);
+    }
+
+    return result;
+}
+
+static void whisper_exp_compute_token_level_timestamps(
+        struct whisper_context & ctx,
+        struct whisper_state & state,
+        int i_segment,
+        float thold_pt,
+        float thold_ptsum) {
+    auto & segment = state.result_all[i_segment];
+    auto & tokens = segment.tokens;
+
+    const int n_samples = state.energy.size();
+
+    if (n_samples == 0) {
+        log("%s: no signal data available\n", __func__);
+        return;
+    }
+
+    const int64_t t0 = segment.t0;
+    const int64_t t1 = segment.t1;
+
+    const int n = tokens.size();
+
+    if (n == 0) {
+        return;
+    }
+
+    if (n == 1) {
+        tokens[0].t0 = t0;
+        tokens[0].t1 = t1;
+
+        return;
+    }
+
+    auto & t_beg = state.t_beg;
+    auto & t_last = state.t_last;
+    auto & tid_last = state.tid_last;
+
+    for (int j = 0; j < n; ++j) {
+        auto & token = tokens[j];
+
+        if (j == 0) {
+            if (token.id == whisper_token_beg(&ctx)) {
+                tokens[j    ].t0 = t0;
+                tokens[j    ].t1 = t0;
+                tokens[j + 1].t0 = t0;
+
+                t_beg = t0;
+                t_last = t0;
+                tid_last = whisper_token_beg(&ctx);
+            } else {
+                tokens[j    ].t0 = t_last;
+            }
+        }
+
+        const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
+
+        tokens[j].id    = token.id;
+        tokens[j].tid   = token.tid;
+        tokens[j].p     = token.p;
+        tokens[j].pt    = token.pt;
+        tokens[j].ptsum = token.ptsum;
+
+        tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
+
+        if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
+            if (j > 0) {
+                tokens[j - 1].t1 = tt;
+            }
+            tokens[j].t0 = tt;
+            tid_last = token.tid;
+        }
+    }
+
+    tokens[n - 2].t1 = t1;
+    tokens[n - 1].t0 = t1;
+    tokens[n - 1].t1 = t1;
+
+    t_last = t1;
+
+    // find intervals of tokens with unknown timestamps
+    // fill the timestamps by proportionally splitting the interval based on the token voice lengths
+    {
+        int p0 = 0;
+        int p1 = 0;
+
+        while (true) {
+            while (p1 < n && tokens[p1].t1 < 0) {
+                p1++;
+            }
+
+            if (p1 >= n) {
+                p1--;
+            }
+
+            //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1);
+
+            if (p1 > p0) {
+                double psum = 0.0;
+                for (int j = p0; j <= p1; j++) {
+                    psum += tokens[j].vlen;
+                }
+
+
//printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + + const double dt = tokens[p1].t1 - tokens[p0].t0; + + // split the time proportionally to the voice length + for (int j = p0 + 1; j <= p1; j++) { + const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; + + tokens[j - 1].t1 = ct; + tokens[j ].t0 = ct; + } + } + + p1++; + p0 = p1; + if (p1 >= n) { + break; + } + } + } + + // fix up (just in case) + for (int j = 0; j < n - 1; j++) { + if (tokens[j].t1 < 0) { + tokens[j + 1].t0 = tokens[j].t1; + } + + if (j > 0) { + if (tokens[j - 1].t1 > tokens[j].t0) { + tokens[j].t0 = tokens[j - 1].t1; + tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); + } + } + } + + // VAD + // expand or contract tokens based on voice activity + { + const int hw = WHISPER_SAMPLE_RATE/8; + + for (int j = 0; j < n; j++) { + if (tokens[j].id >= whisper_token_eot(&ctx)) { + continue; + } + + int s0 = timestamp_to_sample(tokens[j].t0, n_samples); + int s1 = timestamp_to_sample(tokens[j].t1, n_samples); + + const int ss0 = std::max(s0 - hw, 0); + const int ss1 = std::min(s1 + hw, n_samples); + + const int ns = ss1 - ss0; + + float sum = 0.0f; + + for (int k = ss0; k < ss1; k++) { + sum += state.energy[k]; + } + + const float thold = 0.5*sum/ns; + + { + int k = s0; + if (state.energy[k] > thold && j > 0) { + while (k > 0 && state.energy[k] > thold) { + k--; + } + tokens[j].t0 = sample_to_timestamp(k); + if (tokens[j].t0 < tokens[j - 1].t1) { + tokens[j].t0 = tokens[j - 1].t1; + } else { + s0 = k; + } + } else { + while (state.energy[k] < thold && k < s1) { + k++; + } + s0 = k; + tokens[j].t0 = sample_to_timestamp(k); + } + } + + { + int k = s1; + if (state.energy[k] > thold) { + while (k < n_samples - 1 && state.energy[k] > thold) { + k++; + } + tokens[j].t1 = sample_to_timestamp(k); + if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) { + tokens[j].t1 = tokens[j + 1].t0; + } else { + s1 = k; + } + } else { + while (state.energy[k] < thold && k > s0) { + k--; + } + s1 = k; + tokens[j].t1 = sample_to_timestamp(k); + } + } + } + } + + // fixed token expand (optional) + //{ + // const int t_expand = 0; + + // for (int j = 0; j < n; j++) { + // if (j > 0) { + // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); + // } + // if (j < n - 1) { + // tokens[j].t1 = tokens[j].t1 + t_expand; + // } + // } + //} + + // debug info + //for (int j = 0; j < n; ++j) { + // const auto & token = tokens[j]; + // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? 
whisper_token_to_str(&ctx, token.tid) : "[?]";
+    //    printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
+    //            tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id));
+
+    //    if (tokens[j].id >= whisper_token_eot(&ctx)) {
+    //        continue;
+    //    }
+    //}
+}
+
+void whisper_set_log_callback(whisper_log_callback callback) {
+    whisper_log = callback;
+}
diff --git a/cpp/whisper.h b/cpp/whisper.h
new file mode 100644
index 0000000..a5e4936
--- /dev/null
+++ b/cpp/whisper.h
@@ -0,0 +1,531 @@
+#ifndef WHISPER_H
+#define WHISPER_H
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdbool.h>
+
+#ifdef WHISPER_SHARED
+#    ifdef _WIN32
+#        ifdef WHISPER_BUILD
+#            define WHISPER_API __declspec(dllexport)
+#        else
+#            define WHISPER_API __declspec(dllimport)
+#        endif
+#    else
+#        define WHISPER_API __attribute__ ((visibility ("default")))
+#    endif
+#else
+#    define WHISPER_API
+#endif
+
+#define WHISPER_SAMPLE_RATE 16000
+#define WHISPER_N_FFT       400
+#define WHISPER_N_MEL       80
+#define WHISPER_HOP_LENGTH  160
+#define WHISPER_CHUNK_SIZE  30
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+    //
+    // C interface
+    //
+    // The following interface is thread-safe as long as the same whisper_context is not used by multiple threads
+    // concurrently.
+    //
+    // Basic usage:
+    //
+    //     #include "whisper.h"
+    //
+    //     ...
+    //
+    //     struct whisper_context * ctx = whisper_init_from_file("/path/to/ggml-base.en.bin");
+    //
+    //     if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
+    //         fprintf(stderr, "failed to process audio\n");
+    //         return 7;
+    //     }
+    //
+    //     const int n_segments = whisper_full_n_segments(ctx);
+    //     for (int i = 0; i < n_segments; ++i) {
+    //         const char * text = whisper_full_get_segment_text(ctx, i);
+    //         printf("%s", text);
+    //     }
+    //
+    //     whisper_free(ctx);
+    //
+    //     ...
+    //
+    //     This is a demonstration of the most straightforward usage of the library.
+    //     "pcmf32" contains the RAW audio data in 32-bit floating point format.
+    //
+    //     The interface also allows for more fine-grained control over the computation, but it requires a deeper
+    //     understanding of how the model works.
+    //
+
+    struct whisper_context;
+    struct whisper_state;
+    struct whisper_full_params;
+
+    typedef int whisper_token;
+
+    typedef struct whisper_token_data {
+        whisper_token id;  // token id
+        whisper_token tid; // forced timestamp token id
+
+        float p;           // probability of the token
+        float plog;        // log probability of the token
+        float pt;          // probability of the timestamp token
+        float ptsum;       // sum of probabilities of all timestamp tokens
+
+        // token-level timestamp data
+        // do not use if you haven't computed token-level timestamps
+        int64_t t0;        // start time of the token
+        int64_t t1;        // end time of the token
+
+        float vlen;        // voice length of the token
+    } whisper_token_data;
+
+    typedef struct whisper_model_loader {
+        void * context;
+
+        size_t (*read)(void * ctx, void * output, size_t read_size);
+        bool   (*eof)(void * ctx);
+        void   (*close)(void * ctx);
+    } whisper_model_loader;
+
+    // Various functions for loading a ggml whisper model.
+    // Allocate (almost) all memory needed for the model.
+ // Return NULL on failure + WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model); + WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size); + WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader); + + // These are the same as the above, but the internal state of the context is not allocated automatically + // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523) + WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model); + WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size); + WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader); + + WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx); + + // Given a context, enable use of OpenVINO for encode inference. + // model_path: Optional path to OpenVINO encoder IR model. If set to nullptr, + // the path will be generated from the ggml model path that was passed + // in to whisper_init_from_file. For example, if 'path_model' was + // "/path/to/ggml-base.en.bin", then OpenVINO IR model path will be + // assumed to be "/path/to/ggml-base.en-encoder-openvino.xml". + // device: OpenVINO device to run inference on ("CPU", "GPU", etc.) + // cache_dir: Optional cache directory that can speed up init time, especially for + // GPU, by caching compiled 'blobs' there. + // Set to nullptr if not used. + // Returns 0 on success. If OpenVINO is not enabled in build, this simply returns 1. + WHISPER_API int whisper_ctx_init_openvino_encoder( + struct whisper_context * ctx, + const char * model_path, + const char * device, + const char * cache_dir); + + // Frees all allocated memory + WHISPER_API void whisper_free (struct whisper_context * ctx); + WHISPER_API void whisper_free_state(struct whisper_state * state); + WHISPER_API void whisper_free_params(struct whisper_full_params * params); + + // Convert RAW PCM audio to log mel spectrogram. + // The resulting spectrogram is stored inside the default state of the provided whisper context. + // Returns 0 on success + WHISPER_API int whisper_pcm_to_mel( + struct whisper_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + WHISPER_API int whisper_pcm_to_mel_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * samples, + int n_samples, + int n_threads); + + // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. + // The resulting spectrogram is stored inside the default state of the provided whisper context. + // Returns 0 on success + WHISPER_API int whisper_pcm_to_mel_phase_vocoder( + struct whisper_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context. + // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. 
+ // n_mel must be 80 + // Returns 0 on success + WHISPER_API int whisper_set_mel( + struct whisper_context * ctx, + const float * data, + int n_len, + int n_mel); + + WHISPER_API int whisper_set_mel_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * data, + int n_len, + int n_mel); + + // Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context. + // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. + // offset can be used to specify the offset of the first frame in the spectrogram. + // Returns 0 on success + WHISPER_API int whisper_encode( + struct whisper_context * ctx, + int offset, + int n_threads); + + WHISPER_API int whisper_encode_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + int offset, + int n_threads); + + // Run the Whisper decoder to obtain the logits and probabilities for the next token. + // Make sure to call whisper_encode() first. + // tokens + n_tokens is the provided context for the decoder. + // n_past is the number of tokens to use from previous decoder calls. + // Returns 0 on success + // TODO: add support for multiple decoders + WHISPER_API int whisper_decode( + struct whisper_context * ctx, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads); + + WHISPER_API int whisper_decode_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads); + + // Convert the provided text into tokens. + // The tokens pointer must be large enough to hold the resulting tokens. + // Returns the number of tokens on success, no more than n_max_tokens + // Returns -1 on failure + // TODO: not sure if correct + WHISPER_API int whisper_tokenize( + struct whisper_context * ctx, + const char * text, + whisper_token * tokens, + int n_max_tokens); + + // Largest language id (i.e. number of available languages - 1) + WHISPER_API int whisper_lang_max_id(); + + // Return the id of the specified language, returns -1 if not found + // Examples: + // "de" -> 2 + // "german" -> 2 + WHISPER_API int whisper_lang_id(const char * lang); + + // Return the short string of the specified language id (e.g. 
2 -> "de"), returns nullptr if not found + WHISPER_API const char * whisper_lang_str(int id); + + // Use mel data at offset_ms to try and auto-detect the spoken language + // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first + // Returns the top language id or negative on failure + // If not null, fills the lang_probs array with the probabilities of all languages + // The array must be whisper_lang_max_id() + 1 in size + // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 + WHISPER_API int whisper_lang_auto_detect( + struct whisper_context * ctx, + int offset_ms, + int n_threads, + float * lang_probs); + + WHISPER_API int whisper_lang_auto_detect_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + int offset_ms, + int n_threads, + float * lang_probs); + + WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length + WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length + WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); + WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx); + + WHISPER_API int whisper_model_n_vocab (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_state(struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_head (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_layer(struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_state (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_head (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_layer (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_mels (struct whisper_context * ctx); + WHISPER_API int whisper_model_ftype (struct whisper_context * ctx); + WHISPER_API int whisper_model_type (struct whisper_context * ctx); + + // Token logits obtained from the last call to whisper_decode() + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + WHISPER_API float * whisper_get_logits (struct whisper_context * ctx); + WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state); + + // Token Id -> String. 
Uses the vocabulary in the provided context
+    WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
+    WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx);
+
+
+    // Special tokens
+    WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
+
+    // Task tokens
+    WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);
+
+    // Performance information from the default state.
+    WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
+    WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
+
+    // Print system information
+    WHISPER_API const char * whisper_print_system_info(void);
+
+    ////////////////////////////////////////////////////////////////////////////
+
+    // Available sampling strategies
+    enum whisper_sampling_strategy {
+        WHISPER_SAMPLING_GREEDY,      // similar to OpenAI's GreedyDecoder
+        WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
+    };
+
+    // Text segment callback
+    // Called on every newly generated text segment
+    // Use the whisper_full_...() functions to obtain the text segments
+    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);
+
+    // Progress callback
+    typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);
+
+    // Encoder begin callback
+    // If not NULL, called before the encoder starts
+    // If it returns false, the computation is aborted
+    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
+
+    // Logits filter callback
+    // Can be used to modify the logits before sampling
+    // If not NULL, called after applying temperature to logits
+    typedef void (*whisper_logits_filter_callback)(
+        struct whisper_context * ctx,
+        struct whisper_state * state,
+        const whisper_token_data * tokens,
+        int n_tokens,
+        float * logits,
+        void * user_data);
+
+    // Parameters for the whisper_full() function
+    // If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
+    // whisper_full_default_params()
+    struct whisper_full_params {
+        enum whisper_sampling_strategy strategy;
+
+        int n_threads;
+        int n_max_text_ctx;  // max tokens to use from past text as prompt for the decoder
+        int offset_ms;       // start offset in ms
+        int duration_ms;     // audio duration to process in ms
+
+        bool translate;
+        bool no_context;     // do not use past transcription (if any) as initial prompt for the decoder
+        bool single_segment; // force single segment output (useful for streaming)
+        bool print_special;  // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
+ bool print_progress; // print progress information + bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead) + bool print_timestamps; // print timestamps for each text segment when printing realtime + + // [EXPERIMENTAL] token-level timestamps + bool token_timestamps; // enable token-level timestamps + float thold_pt; // timestamp token probability threshold (~0.01) + float thold_ptsum; // timestamp token sum probability threshold (~0.01) + int max_len; // max segment length in characters + bool split_on_word; // split on word rather than on token (when used with max_len) + int max_tokens; // max tokens per segment (0 = no limit) + + // [EXPERIMENTAL] speed-up techniques + // note: these can significantly reduce the quality of the output + bool speed_up; // speed-up the audio by 2x using Phase Vocoder + bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel) + int audio_ctx; // overwrite the audio context size (0 = use default) + + // [EXPERIMENTAL] [TDRZ] tinydiarize + bool tdrz_enable; // enable tinydiarize speaker turn detection + + // tokens to provide to the whisper decoder as initial prompt + // these are prepended to any existing text context from a previous call + const char * initial_prompt; + const whisper_token * prompt_tokens; + int prompt_n_tokens; + + // for auto-detection, set to nullptr, "" or "auto" + const char * language; + bool detect_language; + + // common decoding parameters: + bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 + bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + + float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 + float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 + float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267 + + // fallback parameters + // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278 + float temperature_inc; + float entropy_thold; // similar to OpenAI's "compression_ratio_threshold" + float logprob_thold; + float no_speech_thold; // TODO: not implemented + + struct { + int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 + } greedy; + + struct { + int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 + + float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf + } beam_search; + + // called for every newly generated text segment + whisper_new_segment_callback new_segment_callback; + void * new_segment_callback_user_data; + + // called on each progress update + whisper_progress_callback progress_callback; + void * progress_callback_user_data; + + // called each time before the encoder starts + whisper_encoder_begin_callback encoder_begin_callback; + void * encoder_begin_callback_user_data; + + // called by each decoder to filter obtained logits + whisper_logits_filter_callback logits_filter_callback; + void * logits_filter_callback_user_data; + }; + + // NOTE: this function allocates memory, and it is the responsibility of the caller to free the 
pointer - see whisper_free_params() + WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy); + WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); + + // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + // Not thread safe for same context + // Uses the specified decoding strategy to obtain the text. + WHISPER_API int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples); + + WHISPER_API int whisper_full_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + struct whisper_full_params params, + const float * samples, + int n_samples); + + // Split the input audio in chunks and process each chunk separately using whisper_full_with_state() + // Result is stored in the default state of the context + // Not thread safe if executed in parallel on the same context. + // It seems this approach can offer some speedup in some cases. + // However, the transcription accuracy can be worse at the beginning and end of each chunk. + WHISPER_API int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors); + + // Number of generated text segments + // A segment can be a few words, a sentence, or even a paragraph. + WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx); + WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state); + + // Language id associated with the context's default state + WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx); + + // Language id associated with the provided state + WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state); + + // Get the start and end time of the specified segment + WHISPER_API int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment); + + WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment); + + // Get whether the next segment is predicted as a speaker turn + WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment); + + // Get the text of the specified segment + WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment); + WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment); + + // Get number of tokens in the specified segment + WHISPER_API int whisper_full_n_tokens (struct whisper_context * ctx, int i_segment); + WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment); + + // Get the token text of the specified token in the specified segment + WHISPER_API const char * whisper_full_get_token_text (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token); + + WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token 
whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token); + + // Get token data for the specified token in the specified segment + // This contains probabilities, timestamps, etc. + WHISPER_API whisper_token_data whisper_full_get_token_data (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token); + + // Get the probability of the specified token in the specified segment + WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token); + + //////////////////////////////////////////////////////////////////////////// + + // Temporary helpers needed for exposing ggml interface + + WHISPER_API int whisper_bench_memcpy (int n_threads); + WHISPER_API const char * whisper_bench_memcpy_str (int n_threads); + WHISPER_API int whisper_bench_wsp_ggml_mul_mat (int n_threads); + WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads); + + // Control logging output; default behavior is to print to stderr + + typedef void (*whisper_log_callback)(const char * line); + WHISPER_API void whisper_set_log_callback(whisper_log_callback callback); + +#ifdef __cplusplus +} +#endif + +#endif
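
The comments in cpp/whisper.h above only sketch the intended call sequence in fragments. Below is a minimal, hedged usage sketch that strings those declarations together from C++: the model path and the placeholder silence buffer are assumptions made for illustration, only functions declared in the header are used, and segment/token timestamps are printed in the 10 ms units implied by sample_to_timestamp() in whisper.cpp above.

// Minimal usage sketch for the C API declared in cpp/whisper.h.
// Assumptions: a ggml model at the placeholder path below, and 16 kHz mono
// float PCM in `pcmf32` (silence here, so the example stays self-contained).
#include <cstdio>
#include <vector>

#include "whisper.h"

int main() {
    struct whisper_context * ctx = whisper_init_from_file("models/ggml-base.en.bin");
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // route library log output explicitly (the default goes to stderr)
    whisper_set_log_callback([](const char * line) { fputs(line, stdout); });

    // 10 seconds of silence stands in for real audio
    std::vector<float> pcmf32(10*WHISPER_SAMPLE_RATE, 0.0f);

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.n_threads        = 4;
    wparams.print_progress   = false;
    wparams.token_timestamps = true; // [EXPERIMENTAL] fill per-token t0/t1

    if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) {
        fprintf(stderr, "failed to process audio\n");
        whisper_free(ctx);
        return 1;
    }

    // t0/t1 values are in units of 10 ms (see sample_to_timestamp() in whisper.cpp)
    const int n_segments = whisper_full_n_segments(ctx);
    for (int i = 0; i < n_segments; ++i) {
        printf("[%lld -> %lld] %s\n",
                (long long) whisper_full_get_segment_t0(ctx, i),
                (long long) whisper_full_get_segment_t1(ctx, i),
                whisper_full_get_segment_text(ctx, i));

        for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
            const whisper_token_data td = whisper_full_get_token_data(ctx, i, j);
            printf("  '%s' t0=%lld t1=%lld p=%.3f\n",
                    whisper_full_get_token_text(ctx, i, j),
                    (long long) td.t0, (long long) td.t1, td.p);
        }
    }

    whisper_print_timings(ctx);
    whisper_free(ctx);

    return 0;
}

For finer control, the header's *_no_state / *_with_state variants allow one loaded model context to be paired with separately allocated whisper_state objects (see whisper_init_state() and the #523 note above) instead of the default state used in this sketch.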
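
The whisper_model_loader struct declared in the header lets whisper_init() pull model bytes from a custom source rather than a file path. The following sketch wraps a plain FILE*; the read/eof/close semantics assumed here (read returns the number of bytes copied, close releases the handle) are modeled on stdio and are not spelled out in the header itself.

// Sketch of a custom whisper_model_loader streaming a ggml model from a FILE*.
#include <cstdio>

#include "whisper.h"

static size_t loader_read(void * ctx, void * output, size_t read_size) {
    // copy up to read_size bytes into output, return how many were actually read
    return fread(output, 1, read_size, (FILE *) ctx);
}

static bool loader_eof(void * ctx) {
    return feof((FILE *) ctx) != 0;
}

static void loader_close(void * ctx) {
    fclose((FILE *) ctx);
}

struct whisper_context * whisper_init_from_stdio(const char * path) {
    FILE * fin = fopen(path, "rb");
    if (fin == nullptr) {
        return nullptr;
    }

    whisper_model_loader loader;
    loader.context = fin;
    loader.read    = loader_read;
    loader.eof     = loader_eof;
    loader.close   = loader_close;

    // whisper_init() drives the callbacks to read the model data
    return whisper_init(&loader);
}

The same callback shape could just as well wrap an in-memory buffer or a platform asset API, which is the point of exposing the loader alongside whisper_init_from_file() and whisper_init_from_buffer().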