|
1 | | -// |
2 | | -// ImagesDemoView.swift |
3 | | -// SwiftOpenAIExample |
4 | | -// |
5 | | -// Created by James Rochabrun on 10/24/23. |
6 | | -// |
7 | | - |
8 | 1 | import SwiftUI |
9 | 2 | import SwiftOpenAI |
10 | 3 |
|
11 | 4 | struct ImagesDemoView: View { |
12 | | - |
13 | | - @State private var imagesProvider: ImagesProvider |
14 | | - @State private var isLoading = false |
15 | | - @State private var prompt = "" |
16 | | - @State private var errorMessage = "" |
17 | | - |
18 | | - init(service: OpenAIService) { |
19 | | - _imagesProvider = State(initialValue: ImagesProvider(service: service)) |
20 | | - } |
21 | | - |
22 | | - var body: some View { |
23 | | - ScrollView { |
24 | | - textArea |
25 | | - if !errorMessage.isEmpty { |
26 | | - Text("Error \(errorMessage)") |
27 | | - .bold() |
28 | | - } |
29 | | - ForEach(Array(imagesProvider.images.enumerated()), id: \.offset) { _, url in |
30 | | - AsyncImage(url: url, scale: 1) { image in |
31 | | - image |
32 | | - .resizable() |
33 | | - .aspectRatio(contentMode: .fill) |
34 | | - .clipped() |
35 | | - } placeholder: { |
36 | | - EmptyView() |
| 5 | + enum ImageModel: String, CaseIterable, Identifiable { |
| 6 | + case gptImage1 = "GPT-Image-1" |
| 7 | + case dallE3 = "DALL-E 3" |
| 8 | + case dallE2 = "DALL-E 2" |
| 9 | + |
| 10 | + var id: String { self.rawValue } |
| 11 | + |
| 12 | + var model: CreateImageParameters.Model { |
| 13 | + switch self { |
| 14 | + case .gptImage1: return .gptImage1 |
| 15 | + case .dallE3: return .dallE3 |
| 16 | + case .dallE2: return .dallE2 |
| 17 | + } |
| 18 | + } |
| 19 | + } |
| 20 | + |
| 21 | + enum ImageQuality: String, CaseIterable, Identifiable { |
| 22 | + case high = "High" |
| 23 | + case medium = "Medium" |
| 24 | + case low = "Low" |
| 25 | + case standard = "Standard" |
| 26 | + case hd = "HD" |
| 27 | + |
| 28 | + var id: String { self.rawValue } |
| 29 | + |
| 30 | + var quality: CreateImageParameters.Quality { |
| 31 | + switch self { |
| 32 | + case .high: return .high |
| 33 | + case .medium: return .medium |
| 34 | + case .low: return .low |
| 35 | + case .standard: return .standard |
| 36 | + case .hd: return .hd |
37 | 37 | } |
38 | | - } |
39 | | - } |
40 | | - .overlay( |
41 | | - Group { |
42 | | - if isLoading { |
43 | | - ProgressView() |
44 | | - } else { |
45 | | - EmptyView() |
| 38 | + } |
| 39 | + } |
| 40 | + |
| 41 | + @State private var imagesProvider: ImagesProvider |
| 42 | + @State private var isLoading = false |
| 43 | + @State private var prompt = "" |
| 44 | + @State private var errorMessage = "" |
| 45 | + @State private var selectedModel: ImageModel = .gptImage1 |
| 46 | + @State private var selectedQuality: ImageQuality = .high |
| 47 | + @State private var selectedSize = "1024x1024" |
| 48 | + @State private var imageCount = 1 |
| 49 | + @State private var isAdvancedOptionsExpanded = false |
| 50 | + @State private var isShowingBase64Images = false |
| 51 | + |
| 52 | + init(service: OpenAIService) { |
| 53 | + _imagesProvider = State(initialValue: ImagesProvider(service: service)) |
| 54 | + } |
| 55 | + |
| 56 | + var body: some View { |
| 57 | + ScrollView { |
| 58 | + VStack(spacing: 20) { |
| 59 | + // Title |
| 60 | + Text("OpenAI Image Generation") |
| 61 | + .font(.title) |
| 62 | + .padding(.top) |
| 63 | + |
| 64 | + // Prompt input area |
| 65 | + promptInputArea |
| 66 | + |
| 67 | + // Advanced options (collapsible) |
| 68 | + advancedOptionsArea |
| 69 | + |
| 70 | + // Generate button |
| 71 | + generateButton |
| 72 | + |
| 73 | + // Error message |
| 74 | + if !errorMessage.isEmpty { |
| 75 | + Text("Error: \(errorMessage)") |
| 76 | + .foregroundColor(.red) |
| 77 | + .bold() |
| 78 | + .padding() |
| 79 | + } |
| 80 | + |
| 81 | + // Images display |
| 82 | + imageResultsArea |
46 | 83 | } |
47 | | - } |
48 | | - ) |
49 | | - } |
50 | | - |
51 | | - var textArea: some View { |
52 | | - HStack(spacing: 4) { |
53 | | - TextField("Enter prompt", text: $prompt, axis: .vertical) |
54 | | - .textFieldStyle(.roundedBorder) |
55 | 84 | .padding() |
56 | | - Button { |
| 85 | + } |
| 86 | + .overlay( |
| 87 | + Group { |
| 88 | + if isLoading { |
| 89 | + ZStack { |
| 90 | + Color.black.opacity(0.4) |
| 91 | + VStack { |
| 92 | + ProgressView() |
| 93 | + .scaleEffect(1.5) |
| 94 | + .padding() |
| 95 | + Text("Generating images...") |
| 96 | + .foregroundColor(.white) |
| 97 | + } |
| 98 | + .padding() |
| 99 | + .background(RoundedRectangle(cornerRadius: 10).fill(Color.gray.opacity(0.7))) |
| 100 | + } |
| 101 | + .edgesIgnoringSafeArea(.all) |
| 102 | + } |
| 103 | + } |
| 104 | + ) |
| 105 | + } |
| 106 | + |
| 107 | + private var promptInputArea: some View { |
| 108 | + VStack(alignment: .leading, spacing: 8) { |
| 109 | + Text("Enter a prompt") |
| 110 | + .font(.headline) |
| 111 | + |
| 112 | + TextField("Describe what you want to generate...", text: $prompt, axis: .vertical) |
| 113 | + .textFieldStyle(.roundedBorder) |
| 114 | + .lineLimit(3...6) |
| 115 | + .padding(.bottom, 8) |
| 116 | + } |
| 117 | + } |
| 118 | + |
| 119 | + private var advancedOptionsArea: some View { |
| 120 | + VStack(alignment: .leading, spacing: 8) { |
| 121 | + Button { |
| 122 | + withAnimation { |
| 123 | + isAdvancedOptionsExpanded.toggle() |
| 124 | + } |
| 125 | + } label: { |
| 126 | + HStack { |
| 127 | + Text("Advanced Options") |
| 128 | + .font(.headline) |
| 129 | + Spacer() |
| 130 | + Image(systemName: isAdvancedOptionsExpanded ? "chevron.up" : "chevron.down") |
| 131 | + .animation(.default, value: isAdvancedOptionsExpanded) |
| 132 | + } |
| 133 | + } |
| 134 | + .foregroundColor(.primary) |
| 135 | + |
| 136 | + if isAdvancedOptionsExpanded { |
| 137 | + VStack(spacing: 16) { |
| 138 | + // Model picker |
| 139 | + HStack { |
| 140 | + Text("Model:") |
| 141 | + .frame(width: 100, alignment: .leading) |
| 142 | + |
| 143 | + Picker("Select Model", selection: $selectedModel) { |
| 144 | + ForEach(ImageModel.allCases) { model in |
| 145 | + Text(model.rawValue).tag(model) |
| 146 | + } |
| 147 | + } |
| 148 | + .pickerStyle(MenuPickerStyle()) |
| 149 | + } |
| 150 | + |
| 151 | + // Quality picker |
| 152 | + HStack { |
| 153 | + Text("Quality:") |
| 154 | + .frame(width: 100, alignment: .leading) |
| 155 | + |
| 156 | + Picker("Select Quality", selection: $selectedQuality) { |
| 157 | + ForEach(ImageQuality.allCases) { quality in |
| 158 | + Text(quality.rawValue).tag(quality) |
| 159 | + } |
| 160 | + } |
| 161 | + .pickerStyle(MenuPickerStyle()) |
| 162 | + } |
| 163 | + |
| 164 | + // Size picker |
| 165 | + HStack { |
| 166 | + Text("Size:") |
| 167 | + .frame(width: 100, alignment: .leading) |
| 168 | + |
| 169 | + Picker("Select Size", selection: $selectedSize) { |
| 170 | + Text("1024x1024").tag("1024x1024") |
| 171 | + Text("1536x1024").tag("1536x1024") |
| 172 | + Text("1024x1536").tag("1024x1536") |
| 173 | + } |
| 174 | + .pickerStyle(MenuPickerStyle()) |
| 175 | + } |
| 176 | + |
| 177 | + // Image count stepper |
| 178 | + HStack { |
| 179 | + Text("Count:") |
| 180 | + .frame(width: 100, alignment: .leading) |
| 181 | + |
| 182 | + Stepper("\(imageCount) \(imageCount == 1 ? "image" : "images")", value: $imageCount, in: 1...4) |
| 183 | + } |
| 184 | + } |
| 185 | + .padding() |
| 186 | + .background(Color.gray.opacity(0.1)) |
| 187 | + .cornerRadius(8) |
| 188 | + } |
| 189 | + } |
| 190 | + } |
| 191 | + |
| 192 | + private var generateButton: some View { |
| 193 | + Button { |
57 | 194 | Task { |
58 | | - isLoading = true |
59 | | - defer { isLoading = false } // ensure isLoading is set to false when the |
60 | | - do { |
61 | | - try await imagesProvider.createImages(parameters: .init(prompt: prompt, model: .dalle3(.largeSquare))) |
62 | | - } catch { |
63 | | - errorMessage = "\(error)" |
64 | | - } |
| 195 | + await generateImages() |
| 196 | + } |
| 197 | + } label: { |
| 198 | + Text("Generate Images") |
| 199 | + .frame(maxWidth: .infinity) |
| 200 | + .padding() |
| 201 | + .background(prompt.isEmpty ? Color.gray : Color.blue) |
| 202 | + .foregroundColor(.white) |
| 203 | + .cornerRadius(10) |
| 204 | + } |
| 205 | + .disabled(prompt.isEmpty || isLoading) |
| 206 | + } |
| 207 | + |
| 208 | + private var imageResultsArea: some View { |
| 209 | + VStack(alignment: .leading, spacing: 16) { |
| 210 | + if !imagesProvider.images.isEmpty || !imagesProvider.base64Images.isEmpty { |
| 211 | + HStack { |
| 212 | + Text("Generated Images") |
| 213 | + .font(.headline) |
| 214 | + Spacer() |
| 215 | + |
| 216 | + if !imagesProvider.base64Images.isEmpty { |
| 217 | + Button { |
| 218 | + isShowingBase64Images.toggle() |
| 219 | + } label: { |
| 220 | + Text(isShowingBase64Images ? "Show URL images" : "Show base64 images") |
| 221 | + .font(.caption) |
| 222 | + } |
| 223 | + } |
| 224 | + } |
| 225 | + |
| 226 | + if isShowingBase64Images { |
| 227 | + // Display base64 images |
| 228 | + let uiImages = imagesProvider.getUIImagesFromBase64() |
| 229 | + |
| 230 | + if uiImages.isEmpty { |
| 231 | + Text("No base64 images available") |
| 232 | + .foregroundColor(.gray) |
| 233 | + .padding() |
| 234 | + } else { |
| 235 | + LazyVGrid(columns: [GridItem(.adaptive(minimum: 150), spacing: 16)], spacing: 16) { |
| 236 | + ForEach(Array(uiImages.enumerated()), id: \.offset) { index, image in |
| 237 | + Image(uiImage: image) |
| 238 | + .resizable() |
| 239 | + .aspectRatio(contentMode: .fill) |
| 240 | + .frame(height: 200) |
| 241 | + .clipShape(RoundedRectangle(cornerRadius: 10)) |
| 242 | + .shadow(radius: 5) |
| 243 | + } |
| 244 | + } |
| 245 | + } |
| 246 | + } else { |
| 247 | + // Display URL images |
| 248 | + if imagesProvider.images.isEmpty { |
| 249 | + Text("No URL images available") |
| 250 | + .foregroundColor(.gray) |
| 251 | + .padding() |
| 252 | + } else { |
| 253 | + LazyVGrid(columns: [GridItem(.adaptive(minimum: 150), spacing: 16)], spacing: 16) { |
| 254 | + ForEach(Array(imagesProvider.images.enumerated()), id: \.offset) { index, url in |
| 255 | + AsyncImage(url: url) { phase in |
| 256 | + switch phase { |
| 257 | + case .empty: |
| 258 | + ProgressView() |
| 259 | + case .success(let image): |
| 260 | + image |
| 261 | + .resizable() |
| 262 | + .aspectRatio(contentMode: .fill) |
| 263 | + .frame(height: 200) |
| 264 | + .clipShape(RoundedRectangle(cornerRadius: 10)) |
| 265 | + case .failure: |
| 266 | + Image(systemName: "exclamationmark.triangle") |
| 267 | + .font(.largeTitle) |
| 268 | + .frame(height: 200) |
| 269 | + @unknown default: |
| 270 | + EmptyView() |
| 271 | + } |
| 272 | + } |
| 273 | + .shadow(radius: 5) |
| 274 | + } |
| 275 | + } |
| 276 | + } |
| 277 | + } |
| 278 | + } |
| 279 | + } |
| 280 | + } |
| 281 | + |
| 282 | + private func generateImages() async { |
| 283 | + guard !prompt.isEmpty else { return } |
| 284 | + |
| 285 | + isLoading = true |
| 286 | + errorMessage = "" |
| 287 | + |
| 288 | + do { |
| 289 | + let parameters = CreateImageParameters( |
| 290 | + prompt: prompt, |
| 291 | + model: selectedModel.model, |
| 292 | + n: imageCount, |
| 293 | + // quality: selectedQuality.quality, |
| 294 | + size: selectedSize |
| 295 | + ) |
| 296 | + |
| 297 | + try await imagesProvider.createImages(parameters: parameters) |
| 298 | + |
| 299 | + // If we got base64 images but no URL images, automatically show base64 images |
| 300 | + if imagesProvider.images.isEmpty && !imagesProvider.base64Images.isEmpty { |
| 301 | + isShowingBase64Images = true |
65 | 302 | } |
66 | | - } label: { |
67 | | - Image(systemName: "paperplane") |
68 | | - } |
69 | | - .buttonStyle(.bordered) |
70 | | - } |
71 | | - .padding() |
72 | | - } |
| 303 | + |
| 304 | + } catch { |
| 305 | + errorMessage = error.localizedDescription |
| 306 | + } |
| 307 | + |
| 308 | + isLoading = false |
| 309 | + } |
73 | 310 | } |
0 commit comments