Skip to content

Commit

Permalink
shipit
Browse files Browse the repository at this point in the history
  • Loading branch information
szymonrybczak committed Jul 19, 2024
1 parent c31294c commit bc760dc
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 11 deletions.
129 changes: 123 additions & 6 deletions ios/Ai.mm
Original file line number Diff line number Diff line change
@@ -1,18 +1,135 @@
#import "Ai.h"
#import "MLCEngine.h"

@interface Ai ()

@property (nonatomic, strong) MLCEngine *engine;
@property (nonatomic, strong) NSURL *bundleURL;
@property (nonatomic, strong) NSString *modelPath;
@property (nonatomic, strong) NSString *modelLib;
@property (nonatomic, strong) NSString *displayText;

@end

@implementation Ai

RCT_EXPORT_MODULE()

// MARK: - Lifecycle

/// Designated initializer. Creates the MLC engine and configures the default
/// model. Model artifacts are expected to be packaged under
/// <app bundle>/bundle/<modelPath>.
// NOTE(review): the leftover `multiply` RCT_EXPORT_METHOD fragment from the
// template (no body, undefined args) was removed — it was diff residue and
// did not compile.
- (instancetype)init {
  self = [super init];
  if (self) {
    _engine = [[MLCEngine alloc] init];
    // All model files live inside the app bundle's "bundle" directory.
    _bundleURL = [[[NSBundle mainBundle] bundleURL] URLByAppendingPathComponent:@"bundle"];
    _modelPath = @"Llama-3-8B-Instruct-q3f16_1-MLC";
    _modelLib = @"llama_q3f16_1";
    // Accumulator for generated text; appended to by the completion callbacks.
    _displayText = @"";
  }
  return self;
}

/// Runs a one-shot chat completion for `text` and resolves the promise with
/// the full generated output once the engine reports "usage" (end of stream).
/// @param instanceId Identifier passed from JS (currently only logged).
/// @param text User prompt sent to the model.
// NOTE(review): removed the stray `NSNumber *result = @(a * b);` line — diff
// residue from the deleted `multiply` template method; `a`/`b` were undefined.
RCT_EXPORT_METHOD(doGenerate:(NSString *)instanceId
                  text:(NSString *)text
                  resolve:(RCTPromiseResolveBlock)resolve
                  reject:(RCTPromiseRejectBlock)reject)
{
  NSLog(@"Generating for instance ID: %@, with text: %@", instanceId, text);

  // Reset the accumulator so output from a previous call does not leak into
  // this call's resolved value.
  self.displayText = @"";

  dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
    NSURL *modelLocalURL = [self.bundleURL URLByAppendingPathComponent:self.modelPath];
    NSString *modelLocalPath = [modelLocalURL path];

    [self.engine reloadWithModelPath:modelLocalPath modelLib:self.modelLib];

    NSDictionary *message = @{
      @"role": @"user",
      @"content": text
    };

    [self.engine chatCompletionWithMessages:@[message] completion:^(id response) {
      if ([response isKindOfClass:[NSDictionary class]]) {
        NSDictionary *responseDictionary = (NSDictionary *)response;
        if (responseDictionary[@"usage"]) {
          // A "usage" payload marks the end of the stream: append the usage
          // summary and hand the accumulated text back to JS.
          NSString *usageText = [self getUsageTextFromExtra:responseDictionary[@"usage"][@"extra"]];
          self.displayText = [self.displayText stringByAppendingFormat:@"\n%@", usageText];
          resolve(self.displayText);
        } else {
          // Incremental delta chunk; content may be nil on role-only deltas.
          NSString *content = responseDictionary[@"choices"][0][@"delta"][@"content"];
          if (content) {
            self.displayText = [self.displayText stringByAppendingString:content];
          }
        }
      } else if ([response isKindOfClass:[NSString class]]) {
        self.displayText = [self.displayText stringByAppendingString:(NSString *)response];
      }
    }];
  });
}

/// Streams a chat completion for `text`. The promise resolves with the full
/// accumulated output when the engine reports "usage" (end of stream).
/// @param instanceId Identifier passed from JS (currently only logged).
/// @param text User prompt sent to the model.
// NOTE(review): per-chunk `onStreamProgress` events are currently disabled
// (commented out below) — so this behaves like doGenerate until event
// emission is wired up.
RCT_EXPORT_METHOD(doStream:(NSString *)instanceId
                  text:(NSString *)text
                  resolve:(RCTPromiseResolveBlock)resolve
                  reject:(RCTPromiseRejectBlock)reject)
{
  NSLog(@"Streaming for instance ID: %@, with text: %@", instanceId, text);

  // Reset the accumulator so output from a previous call does not leak into
  // this call's resolved value.
  self.displayText = @"";

  dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
    NSURL *modelLocalURL = [self.bundleURL URLByAppendingPathComponent:self.modelPath];
    NSString *modelLocalPath = [modelLocalURL path];

    [self.engine reloadWithModelPath:modelLocalPath modelLib:self.modelLib];

    NSDictionary *message = @{
      @"role": @"user",
      @"content": text
    };

    [self.engine chatCompletionWithMessages:@[message] completion:^(id response) {
      if ([response isKindOfClass:[NSDictionary class]]) {
        NSDictionary *responseDictionary = (NSDictionary *)response;
        if (responseDictionary[@"usage"]) {
          // End of stream: append usage summary and resolve with everything
          // accumulated so far.
          NSString *usageText = [self getUsageTextFromExtra:responseDictionary[@"usage"][@"extra"]];
          self.displayText = [self.displayText stringByAppendingFormat:@"\n%@", usageText];
          resolve(self.displayText);
        } else {
          // Incremental delta chunk; content may be nil on role-only deltas.
          NSString *content = responseDictionary[@"choices"][0][@"delta"][@"content"];
          if (content) {
            self.displayText = [self.displayText stringByAppendingString:content];
            // [self sendEventWithName:@"onStreamProgress" body:@{@"text": content}];
          }
        }
      } else if ([response isKindOfClass:[NSString class]]) {
        NSString *content = (NSString *)response;
        self.displayText = [self.displayText stringByAppendingString:content];
        // [self sendEventWithName:@"onStreamProgress" body:@{@"text": content}];
      }
    }];
  });
}


/// Resolves with metadata about the currently configured model.
/// NOTE(review): `name` is currently ignored — lookup by name is not
/// implemented yet; the statically configured path/lib are always returned.
RCT_EXPORT_METHOD(getModel:(NSString *)name
                  resolve:(RCTPromiseResolveBlock)resolve
                  reject:(RCTPromiseRejectBlock)reject)
{
  NSLog(@"Getting model: %@", name);

  // For now, we're just returning the model path and lib.
  resolve(@{
    @"path": self.modelPath,
    @"lib": self.modelLib
  });
}

/// Produces a human-readable summary of the "extra" usage payload appended to
/// the generated text at end of stream.
/// NOTE(review): placeholder implementation — simply relies on
/// -[NSDictionary description]; replace with proper formatting later.
- (NSString *)getUsageTextFromExtra:(NSDictionary *)extra {
  NSString *summary = [extra description];
  return summary;
}

/// RCTEventEmitter hook: names of events this module can emit to JS.
// NOTE(review): removed the stray `resolve(result);` statement that preceded
// this method — diff residue from the deleted `multiply` template method; it
// sat outside any method body and referenced an undefined `result`.
- (NSArray<NSString *> *)supportedEvents {
  return @[@"onStreamProgress"];
}

// Don't compile this code when we build for the old architecture.
Expand Down
File renamed without changes.
13 changes: 13 additions & 0 deletions ios/LLMEngine.mm
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ - (instancetype)init {
if (self = [super init]) {
// load chat module
const PackedFunc* f_json_ffi_create = Registry::Get("mlc.json_ffi.CreateJSONFFIEngine");
NSLog(@"Listing all available functions in the global TVM registry:");
for (const auto& kv : ::tvm::runtime::Registry::ListNames()) {
NSLog(@"Function: %s", kv.c_str());
}

if (!f_json_ffi_create) {
NSLog(@"Error: Cannot find mlc.json_ffi.CreateJSONFFIEngine in the registry");
// You might want to list available functions in the registry for debugging
// This is just a pseudocode example, adjust according to TVM's API
for (const auto& name : Registry::ListNames()) {
NSLog(@"Available function: %s", name.c_str());
}
}
ICHECK(f_json_ffi_create) << "Cannot find mlc.json_ffi.CreateJSONFFIEngine";
json_ffi_engine_ = (*f_json_ffi_create)();
init_background_engine_func_ = json_ffi_engine_->GetFunction("init_background_engine");
Expand Down
3 changes: 1 addition & 2 deletions MLCEngine.h → ios/MLCEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ NS_ASSUME_NONNULL_BEGIN
- (void)unload;

- (void)chatCompletionWithMessages:(NSArray *)messages
completion:(void (^)(NSString *response))completion;

completion:(void (^)(id response))completion;
@end

NS_ASSUME_NONNULL_END
10 changes: 7 additions & 3 deletions src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ const Ai = AiModule
);

export async function getModel(name: string): Promise<AiModel> {

Check failure on line 34 in src/index.tsx

View workflow job for this annotation

GitHub Actions / lint

'name' is defined but never used. Allowed unused args must match /^_/u
const instanceDataJson = await Ai.getModel(name);
const instanceData: ModelInstance = JSON.parse(instanceDataJson);
return new AiModel(instanceData);
// const instanceDataJson = await Ai.getModel(name);
// console.log(instanceDataJson);
// const instanceData: ModelInstance = JSON.parse(instanceDataJson);
// return new AiModel(instanceData);
}

export interface ModelInstance {
Expand Down Expand Up @@ -101,3 +102,6 @@ class AiModel implements LanguageModelV1 {

// Add other methods here as needed
}
// Re-export the native module's generation entry points so consumers can
// call them without going through the Ai object.
export const { doGenerate, doStream } = Ai;

0 comments on commit bc760dc

Please sign in to comment.