1+ #include  < jni.h> 
2+ #include  " diffusion_session.h" 
3+ #include  " nlohmann/json.hpp" 
4+ #include  " mls_log.h" 
5+ 
6+ using  namespace  mls ; 
7+ using  namespace  nlohmann ; 
8+ 
9+ extern  " C" 
10+ JNIEXPORT void  JNICALL
11+ Java_com_alibaba_mnnllm_android_llm_DiffusionSession_resetNative (JNIEnv *env, jobject thiz,
12+                                                                  jlong instance_id) {
13+ 
14+ }
15+ 
16+ extern  " C" 
17+ JNIEXPORT void  JNICALL
18+ Java_com_alibaba_mnnllm_android_llm_DiffusionSession_releaseNative (JNIEnv *env, jobject thiz,
19+                                jlong instance_id) {
20+     auto * diffusion = reinterpret_cast <DiffusionSession*>(instance_id);
21+     delete  diffusion;
22+ }
23+ 
24+ extern  " C" 
25+ JNIEXPORT jlong JNICALL
26+ Java_com_alibaba_mnnllm_android_llm_DiffusionSession_initNative (JNIEnv *env,
27+                                                                 jobject thiz,
28+                                                                 jstring config_path,
29+                                                                 jstring extra_config_j) {
30+     MNN_DEBUG (" DiffusionSession::initNative" 
31+     const  char * config_path_cstr = env->GetStringUTFChars (config_path, nullptr );
32+     const  char * extra_json_config_cstr = env->GetStringUTFChars (extra_config_j, nullptr );
33+     MNN_DEBUG (" DiffusionSession::initNative config_path_cstr : %s extra_json_config_cstr: %s" 
34+     json extra_json_config = json::parse (extra_json_config_cstr);
35+     std::string diffusion_memory_mode = extra_json_config[" diffusion_memory_mode" 
36+     int  diffusion_memory_mode_int = std::stoi (diffusion_memory_mode);
37+     auto  diffusion = new  DiffusionSession (config_path_cstr, diffusion_memory_mode_int);
38+     env->ReleaseStringUTFChars (extra_config_j, extra_json_config_cstr);
39+     env->ReleaseStringUTFChars (config_path, config_path_cstr);
40+     return  reinterpret_cast <jlong>(diffusion);
41+ }
42+ extern  " C" 
43+ JNIEXPORT jobject JNICALL
44+ Java_com_alibaba_mnnllm_android_llm_DiffusionSession_submitDiffusionNative (JNIEnv *env,
45+                                                                            jobject thiz,
46+                                                                            jlong instance_id,
47+                                                                            jstring input,
48+                                                                            jstring joutput_path,
49+                                                                            jint iter_num,
50+                                                                            jint random_seed,
51+                                                                            jobject progress_listener) {
52+     auto * diffusion = reinterpret_cast <DiffusionSession*>(instance_id); //  Cast back to Llm*
53+     if  (!diffusion) {
54+         return  nullptr ;
55+     }
56+     jclass progressListenerClass = env->GetObjectClass (progress_listener);
57+     jmethodID onProgressMethod = env->GetMethodID (progressListenerClass, " onProgress" " (Ljava/lang/String;)Z" 
58+     if  (!onProgressMethod) {
59+         MNN_DEBUG (" ProgressListener onProgress method not found." 
60+     }
61+     std::string prompt = env->GetStringUTFChars (input, nullptr );
62+     std::string output_path = env->GetStringUTFChars (joutput_path, nullptr );
63+     auto  start = std::chrono::high_resolution_clock::now ();
64+     diffusion->Run (prompt,
65+                    output_path,
66+                    iter_num,
67+                    random_seed,
68+                    [env, progress_listener, onProgressMethod](int  progress) {
69+                        if  (progress_listener && onProgressMethod) {
70+                            jstring javaString =  env->NewStringUTF (std::to_string (progress).c_str ());
71+                            env->CallBooleanMethod (progress_listener, onProgressMethod,  javaString);
72+                            env->DeleteLocalRef (javaString);
73+                        }
74+                    });
75+     auto  end = std::chrono::high_resolution_clock::now ();
76+     auto  duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count ();
77+     jclass hashMapClass = env->FindClass (" java/util/HashMap" 
78+     jmethodID hashMapInit = env->GetMethodID (hashMapClass, " <init>" " ()V" 
79+     jmethodID putMethod = env->GetMethodID (hashMapClass, " put" " (Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;" 
80+     jobject hashMap = env->NewObject (hashMapClass, hashMapInit);
81+     env->CallObjectMethod (hashMap, putMethod, env->NewStringUTF (" total_timeus" NewObject (env->FindClass (" java/lang/Long" GetMethodID (env->FindClass (" java/lang/Long" " <init>" " (J)V" 
82+     return  hashMap;
83+ }
0 commit comments