Skip to content

Commit 5d20545

Browse files
committed
VertexAI - Add logic for text input/output
1 parent ba46a89 commit 5d20545

File tree

6 files changed

+558
-64
lines changed

6 files changed

+558
-64
lines changed

vertexai/src/Candidate.cs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
using System;
1818
using System.Collections.Generic;
19+
using System.Collections.ObjectModel;
1920

2021
namespace Firebase.VertexAI {
2122

@@ -32,11 +33,57 @@ public enum FinishReason {
3233
MalformedFunctionCall,
3334
}
3435

36+
/// <summary>
37+
/// A struct representing a possible reply to a content generation prompt.
38+
/// Each content generation prompt may produce multiple candidate responses.
39+
/// </summary>
3540
public readonly struct Candidate {
41+
private readonly ReadOnlyCollection<SafetyRating> _safetyRatings;
42+
43+
/// <summary>
44+
/// The response’s content.
45+
/// </summary>
3646
public ModelContent Content { get; }
37-
public IEnumerable<SafetyRating> SafetyRatings { get; }
47+
48+
/// <summary>
49+
/// The safety rating of the response content.
50+
/// </summary>
51+
public IEnumerable<SafetyRating> SafetyRatings =>
52+
_safetyRatings ?? new ReadOnlyCollection<SafetyRating>(new List<SafetyRating>());
53+
54+
/// <summary>
55+
/// The reason the model stopped generating content, if it exists;
56+
/// for example, if the model generated a predefined stop sequence.
57+
/// </summary>
3858
public FinishReason? FinishReason { get; }
59+
60+
/// <summary>
61+
/// Cited works in the model’s response content, if it exists.
62+
/// </summary>
3963
public CitationMetadata? CitationMetadata { get; }
64+
65+
// Hidden constructor, users don't need to make this, though they still technically can.
66+
internal Candidate(ModelContent content, List<SafetyRating> safetyRatings,
67+
FinishReason? finishReason, CitationMetadata? citationMetadata) {
68+
Content = content;
69+
_safetyRatings = new ReadOnlyCollection<SafetyRating>(safetyRatings ?? new List<SafetyRating>());
70+
FinishReason = finishReason;
71+
CitationMetadata = citationMetadata;
72+
}
73+
74+
internal static Candidate FromJson(Dictionary<string, object> jsonDict) {
75+
ModelContent content = new();
76+
if (jsonDict.TryGetValue("content", out object contentObj)) {
77+
if (contentObj is not Dictionary<string, object> contentDict) {
78+
throw new VertexAISerializationException("Invalid JSON format: 'content' is not a dictionary.");
79+
}
80+
// We expect this to be another dictionary to convert
81+
content = ModelContent.FromJson(contentDict);
82+
}
83+
84+
// TODO: Parse SafetyRatings, FinishReason, and CitationMetadata
85+
return new Candidate(content, null, null, null);
86+
}
4087
}
4188

4289
}

vertexai/src/GenerateContentResponse.cs

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,87 @@
1515
*/
1616

1717
using System.Collections.Generic;
18+
using System.Collections.ObjectModel;
19+
using System.Linq;
20+
using Google.MiniJSON;
1821

1922
namespace Firebase.VertexAI {
2023

24+
/// <summary>
25+
/// The model's response to a generate content request.
26+
/// </summary>
2127
public readonly struct GenerateContentResponse {
22-
public IEnumerable<Candidate> Candidates { get; }
28+
private readonly ReadOnlyCollection<Candidate> _candidates;
29+
30+
/// <summary>
31+
/// A list of candidate response content, ordered from best to worst.
32+
/// </summary>
33+
public IEnumerable<Candidate> Candidates =>
34+
_candidates ?? new ReadOnlyCollection<Candidate>(new List<Candidate>());
35+
36+
/// <summary>
37+
/// A value containing the safety ratings for the response, or,
38+
/// if the request was blocked, a reason for blocking the request.
39+
/// </summary>
2340
public PromptFeedback? PromptFeedback { get; }
41+
42+
/// <summary>
43+
/// Token usage metadata for processing the generate content request.
44+
/// </summary>
2445
public UsageMetadata? UsageMetadata { get; }
2546

26-
// Helper properties
27-
// The response's content as text, if it exists
28-
public string Text { get; }
47+
/// <summary>
48+
/// The response's content as text, if it exists
49+
/// </summary>
50+
public string Text {
51+
get {
52+
// Concatenate all of the text parts from the first candidate.
53+
return string.Join(" ",
54+
Candidates.FirstOrDefault().Content.Parts
55+
.OfType<ModelContent.TextPart>().Select(tp => tp.Text));
56+
}
57+
}
58+
59+
/// <summary>
60+
/// Returns function calls found in any Parts of the first candidate of the response, if any.
61+
/// </summary>
62+
public IEnumerable<ModelContent.FunctionCallPart> FunctionCalls {
63+
get {
64+
return Candidates.FirstOrDefault().Content.Parts.OfType<ModelContent.FunctionCallPart>();
65+
}
66+
}
67+
68+
// Hidden constructor, users don't need to make this, though they still technically can.
69+
internal GenerateContentResponse(List<Candidate> candidates, PromptFeedback? promptFeedback,
70+
UsageMetadata? usageMetadata) {
71+
_candidates = new ReadOnlyCollection<Candidate>(candidates ?? new List<Candidate>());
72+
PromptFeedback = promptFeedback;
73+
UsageMetadata = usageMetadata;
74+
}
75+
76+
internal static GenerateContentResponse FromJson(string jsonString) {
77+
return FromJson(Json.Deserialize(jsonString) as Dictionary<string, object>);
78+
}
79+
80+
internal static GenerateContentResponse FromJson(Dictionary<string, object> jsonDict) {
81+
// Parse the Candidates
82+
List<Candidate> candidates = new();
83+
if (jsonDict.TryGetValue("candidates", out object candidatesObject)) {
84+
if (candidatesObject is not List<object> listOfCandidateObjects) {
85+
throw new VertexAISerializationException("Invalid JSON format: 'candidates' is not a list.");
86+
}
87+
88+
candidates = listOfCandidateObjects
89+
.Select(o => o as Dictionary<string, object>)
90+
.Where(dict => dict != null)
91+
.Select(Candidate.FromJson)
92+
.ToList();
93+
}
94+
95+
// TODO: Parse PromptFeedback and UsageMetadata
2996

30-
// Returns function calls found in any Parts of the first candidate of the response, if any.
31-
public IEnumerable<ModelContent.FunctionCallPart> FunctionCalls { get; }
97+
return new GenerateContentResponse(candidates, null, null);
98+
}
3299
}
33100

34101
public enum BlockReason {

vertexai/src/GenerativeModel.cs

Lines changed: 152 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,62 +16,193 @@
1616

1717
using System;
1818
using System.Collections.Generic;
19+
using System.Linq;
20+
using System.Net.Http;
21+
using System.Text;
1922
using System.Threading.Tasks;
23+
using Google.MiniJSON;
2024

2125
namespace Firebase.VertexAI {
2226

27+
/// <summary>
28+
/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
29+
/// content based on various input types.
30+
/// </summary>
2331
public class GenerativeModel {
32+
private FirebaseApp _firebaseApp;
33+
34+
// Various setting fields provided by the user
35+
private string _location;
36+
private string _modelName;
37+
private GenerationConfig? _generationConfig;
38+
private SafetySetting[] _safetySettings;
39+
private Tool[] _tools;
40+
private ToolConfig? _toolConfig;
41+
private ModelContent? _systemInstruction;
42+
private RequestOptions? _requestOptions;
43+
44+
HttpClient _httpClient;
45+
46+
internal GenerativeModel(FirebaseApp firebaseApp,
47+
string location,
48+
string modelName,
49+
GenerationConfig? generationConfig = null,
50+
SafetySetting[] safetySettings = null,
51+
Tool[] tools = null,
52+
ToolConfig? toolConfig = null,
53+
ModelContent? systemInstruction = null,
54+
RequestOptions? requestOptions = null) {
55+
_firebaseApp = firebaseApp;
56+
_location = location;
57+
_modelName = modelName;
58+
_generationConfig = generationConfig;
59+
_safetySettings = safetySettings;
60+
_tools = tools;
61+
_toolConfig = toolConfig;
62+
_systemInstruction = systemInstruction;
63+
_requestOptions = requestOptions;
64+
65+
// Create a HttpClient using the timeout requested, or the default one.
66+
_httpClient = new HttpClient() {
67+
Timeout = requestOptions?.Timeout ?? RequestOptions.DefaultTimeout
68+
};
69+
}
70+
71+
#region Public API
72+
/// <summary>
73+
/// Generates new content from input ModelContent given to the model as a prompt.
74+
/// </summary>
75+
/// <param name="content">The input(s) given to the model as a prompt.</param>
76+
/// <returns>The generated content response from the model.</returns>
77+
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
2478
public Task<GenerateContentResponse> GenerateContentAsync(
2579
params ModelContent[] content) {
26-
throw new NotImplementedException();
80+
return GenerateContentAsync((IEnumerable<ModelContent>)content);
2781
}
82+
/// <summary>
83+
/// Generates new content from input text given to the model as a prompt.
84+
/// </summary>
85+
/// <param name="content">The text given to the model as a prompt.</param>
86+
/// <returns>The generated content response from the model.</returns>
87+
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
2888
public Task<GenerateContentResponse> GenerateContentAsync(
29-
IEnumerable<ModelContent> content) {
30-
throw new NotImplementedException();
89+
string text) {
90+
return GenerateContentAsync(new ModelContent[] { ModelContent.Text(text) });
3191
}
92+
/// <summary>
93+
/// Generates new content from input ModelContent given to the model as a prompt.
94+
/// </summary>
95+
/// <param name="content">The input(s) given to the model as a prompt.</param>
96+
/// <returns>The generated content response from the model.</returns>
97+
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
3298
public Task<GenerateContentResponse> GenerateContentAsync(
33-
string text) {
34-
throw new NotImplementedException();
99+
IEnumerable<ModelContent> content) {
100+
return GenerateContentAsyncInternal(content);
35101
}
36102

37-
// The build logic isn't able to resolve IAsyncEnumerable for some reason, even
38-
// though it is usable in Unity 2021.3. Will need to investigate further.
39-
/*
103+
#define HIDE_IASYNCENUMERABLE
104+
#if !defined(HIDE_IASYNCENUMERABLE)
40105
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
41106
params ModelContent[] content) {
42-
throw new NotImplementedException();
107+
return GenerateContentStreamAsync((IEnumerable<ModelContent>)content);
43108
}
44109
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
45-
IEnumerable<ModelContent> content) {
46-
throw new NotImplementedException();
110+
string text) {
111+
return GenerateContentStreamAsync(new ModelContent[] { ModelContent.Text(text) });
47112
}
48113
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
49-
string text) {
50-
throw new NotImplementedException();
114+
IEnumerable<ModelContent> content) {
115+
return GenerateContentStreamAsyncInternal(content);
51116
}
52-
*/
117+
#endif
53118

54119
public Task<CountTokensResponse> CountTokensAsync(
55120
params ModelContent[] content) {
56-
throw new NotImplementedException();
121+
return CountTokensAsync((IEnumerable<ModelContent>)content);
57122
}
58123
public Task<CountTokensResponse> CountTokensAsync(
59-
IEnumerable<ModelContent> content) {
60-
throw new NotImplementedException();
124+
string text) {
125+
return CountTokensAsync(new ModelContent[] { ModelContent.Text(text) });
61126
}
62127
public Task<CountTokensResponse> CountTokensAsync(
63-
string text) {
64-
throw new NotImplementedException();
128+
IEnumerable<ModelContent> content) {
129+
return CountTokensAsyncInternal(content);
65130
}
66131

67132
public Chat StartChat(params ModelContent[] history) {
68-
throw new NotImplementedException();
133+
return StartChat((IEnumerable<ModelContent>)history);
69134
}
70135
public Chat StartChat(IEnumerable<ModelContent> history) {
136+
// TODO: Implementation
137+
throw new NotImplementedException();
138+
}
139+
#endregion
140+
141+
private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
142+
IEnumerable<ModelContent> content) {
143+
string bodyJson = ModelContentsToJson(content);
144+
145+
UnityEngine.Debug.Log($"Going to try to send: {bodyJson}");
146+
147+
HttpRequestMessage request = new(HttpMethod.Post, GetURL() + ":generateContent");
148+
149+
// Set the request headers
150+
request.Headers.Add("x-goog-api-key", _firebaseApp.Options.ApiKey);
151+
request.Headers.Add("x-goog-api-client", "genai-csharp/0.1.0");
152+
153+
// Set the content
154+
request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json");
155+
156+
UnityEngine.Debug.Log("Request? " + request);
157+
158+
HttpResponseMessage response = await _httpClient.SendAsync(request);
159+
// TODO: Convert any timeout exception into a VertexAI equivalent
160+
// TODO: Convert any HttpRequestExceptions, see:
161+
// https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
162+
// https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
163+
response.EnsureSuccessStatusCode();
164+
165+
string result = await response.Content.ReadAsStringAsync();
166+
167+
UnityEngine.Debug.Log("Got a valid response at least: \n" + result);
168+
169+
return GenerateContentResponse.FromJson(result);
170+
}
171+
172+
#if !defined(HIDE_IASYNCENUMERABLE)
173+
private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsyncInternal(
174+
IEnumerable<ModelContent> content) {
175+
// TODO: Implementation
176+
await Task.CompletedTask;
177+
yield return new GenerateContentResponse();
71178
throw new NotImplementedException();
72179
}
180+
#endif
73181

74-
// Note: No public constructor, get one through VertexAI.GetGenerativeModel
182+
private async Task<CountTokensResponse> CountTokensAsyncInternal(
183+
IEnumerable<ModelContent> content) {
184+
// TODO: Implementation
185+
await Task.CompletedTask;
186+
throw new NotImplementedException();
187+
}
188+
189+
private string GetURL() {
190+
return "https://firebaseml.googleapis.com/v2beta" +
191+
"/projects/" + _firebaseApp.Options.ProjectId +
192+
"/locations/" + _location +
193+
"/publishers/google/models/" + _modelName;
194+
}
195+
196+
private string ModelContentsToJson(IEnumerable<ModelContent> contents) {
197+
Dictionary<string, object> jsonDict = new()
198+
{
199+
// Convert the Contents into a list of Json dictionaries
200+
["contents"] = contents.Select(c => c.ToJson()).ToList()
201+
// TODO: All the other settings
202+
};
203+
204+
return Json.Serialize(jsonDict);
205+
}
75206
}
76207

77208
}

0 commit comments

Comments
 (0)