Skip to content

Commit 67709fd

Browse files
authored
Merge pull request #3504 from Juude/bugfix/input_press
refactor chat activity code and fix some bugs
2 parents 16bc909 + 17d8cb2 commit 67709fd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1394
-1706
lines changed

apps/Android/MnnLlmChat/app/src/main/cpp/CMakeLists.txt

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ project("mnnllmapp")
2727
add_library(${CMAKE_PROJECT_NAME} SHARED
2828
# List C/C++ source files with relative paths to this CMakeLists.txt.
2929
llm_mnn_jni.cpp
30+
diffusion_jni.cpp
3031
diffusion_session.cpp
3132
llm_session.cpp
3233
)
@@ -84,10 +85,4 @@ target_link_libraries(${CMAKE_PROJECT_NAME}
8485
android
8586
log
8687
MNN
87-
# llm
88-
# mnn_Express
89-
# mnn_cl
90-
# mnn_audio
91-
# mnn_cv
92-
# diffusion
9388
)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#include <jni.h>
2+
#include "diffusion_session.h"
3+
#include "nlohmann/json.hpp"
4+
#include "mls_log.h"
5+
6+
using namespace mls;
7+
using namespace nlohmann;
8+
9+
extern "C"
10+
JNIEXPORT void JNICALL
11+
Java_com_alibaba_mnnllm_android_llm_DiffusionSession_resetNative(JNIEnv *env, jobject thiz,
12+
jlong instance_id) {
13+
14+
}
15+
16+
extern "C"
17+
JNIEXPORT void JNICALL
18+
Java_com_alibaba_mnnllm_android_llm_DiffusionSession_releaseNative(JNIEnv *env, jobject thiz,
19+
jlong instance_id) {
20+
auto* diffusion = reinterpret_cast<DiffusionSession*>(instance_id);
21+
delete diffusion;
22+
}
23+
24+
extern "C"
25+
JNIEXPORT jlong JNICALL
26+
Java_com_alibaba_mnnllm_android_llm_DiffusionSession_initNative(JNIEnv *env,
27+
jobject thiz,
28+
jstring config_path,
29+
jstring extra_config_j) {
30+
MNN_DEBUG("DiffusionSession::initNative");
31+
const char* config_path_cstr = env->GetStringUTFChars(config_path, nullptr);
32+
const char* extra_json_config_cstr = env->GetStringUTFChars(extra_config_j, nullptr);
33+
MNN_DEBUG("DiffusionSession::initNative config_path_cstr : %s extra_json_config_cstr: %s", config_path_cstr, extra_json_config_cstr);
34+
json extra_json_config = json::parse(extra_json_config_cstr);
35+
std::string diffusion_memory_mode = extra_json_config["diffusion_memory_mode"];
36+
int diffusion_memory_mode_int = std::stoi(diffusion_memory_mode);
37+
auto diffusion = new DiffusionSession(config_path_cstr, diffusion_memory_mode_int);
38+
env->ReleaseStringUTFChars(extra_config_j, extra_json_config_cstr);
39+
env->ReleaseStringUTFChars(config_path, config_path_cstr);
40+
return reinterpret_cast<jlong>(diffusion);
41+
}
42+
extern "C"
43+
JNIEXPORT jobject JNICALL
44+
Java_com_alibaba_mnnllm_android_llm_DiffusionSession_submitDiffusionNative(JNIEnv *env,
45+
jobject thiz,
46+
jlong instance_id,
47+
jstring input,
48+
jstring joutput_path,
49+
jint iter_num,
50+
jint random_seed,
51+
jobject progress_listener) {
52+
auto* diffusion = reinterpret_cast<DiffusionSession*>(instance_id); // Cast back to Llm*
53+
if (!diffusion) {
54+
return nullptr;
55+
}
56+
jclass progressListenerClass = env->GetObjectClass(progress_listener);
57+
jmethodID onProgressMethod = env->GetMethodID(progressListenerClass, "onProgress", "(Ljava/lang/String;)Z");
58+
if (!onProgressMethod) {
59+
MNN_DEBUG("ProgressListener onProgress method not found.");
60+
}
61+
std::string prompt = env->GetStringUTFChars(input, nullptr);
62+
std::string output_path = env->GetStringUTFChars(joutput_path, nullptr);
63+
auto start = std::chrono::high_resolution_clock::now();
64+
diffusion->Run(prompt,
65+
output_path,
66+
iter_num,
67+
random_seed,
68+
[env, progress_listener, onProgressMethod](int progress) {
69+
if (progress_listener && onProgressMethod) {
70+
jstring javaString = env->NewStringUTF(std::to_string(progress).c_str());
71+
env->CallBooleanMethod(progress_listener, onProgressMethod, javaString);
72+
env->DeleteLocalRef(javaString);
73+
}
74+
});
75+
auto end = std::chrono::high_resolution_clock::now();
76+
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
77+
jclass hashMapClass = env->FindClass("java/util/HashMap");
78+
jmethodID hashMapInit = env->GetMethodID(hashMapClass, "<init>", "()V");
79+
jmethodID putMethod = env->GetMethodID(hashMapClass, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
80+
jobject hashMap = env->NewObject(hashMapClass, hashMapInit);
81+
env->CallObjectMethod(hashMap, putMethod, env->NewStringUTF("total_timeus"), env->NewObject(env->FindClass("java/lang/Long"), env->GetMethodID(env->FindClass("java/lang/Long"), "<init>", "(J)V"), duration));
82+
return hashMap;
83+
}

apps/Android/MnnLlmChat/app/src/main/cpp/diffusion_session.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include "mls_log.h"
88
#include <memory>
99
#include <utility>
10-
mls::DiffusionSession::DiffusionSession(std::string resource_path, int memory_mode):
10+
mls::DiffusionSession::DiffusionSession(std::string resource_path, int memory_mode):
1111
resource_path_(std::move(resource_path)),
1212
memory_mode_(memory_mode){
1313
this->diffusion_= std::make_unique<Diffusion>(

0 commit comments

Comments
 (0)