diff --git a/package-lock.json b/package-lock.json index 89eb096..ca2e846 100644 --- a/package-lock.json +++ b/package-lock.json @@ -17,6 +17,7 @@ "@fontsource-variable/inter": "^5.1.0", "@fontsource-variable/jetbrains-mono": "^5.1.1", "@grpc/grpc-js": "^1.13.4", + "@grpc/reflection": "^1.0.4", "@styled/typescript-styled-plugin": "^1.0.1", "@types/react": "^19.0.0", "@types/react-dom": "^19.0.0", @@ -1628,6 +1629,19 @@ "node": ">=6" } }, + "node_modules/@grpc/reflection": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@grpc/reflection/-/reflection-1.0.4.tgz", + "integrity": "sha512-znA8v4AviOD3OPOxy11pxrtP8k8DanpefeTymS8iGW1fVr1U2cHuzfhYqDPHnVNDf4qvF9E25KtSihPy2DBWfQ==", + "dev": true, + "dependencies": { + "@grpc/proto-loader": "^0.7.13", + "protobufjs": "^7.2.5" + }, + "peerDependencies": { + "@grpc/grpc-js": "^1.8.21" + } + }, "node_modules/@isaacs/cliui": { "version": "8.0.2", "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", diff --git a/package.json b/package.json index a7b441c..aa3135b 100644 --- a/package.json +++ b/package.json @@ -31,6 +31,7 @@ "@fontsource-variable/inter": "^5.1.0", "@fontsource-variable/jetbrains-mono": "^5.1.1", "@grpc/grpc-js": "^1.13.4", + "@grpc/reflection": "^1.0.4", "@styled/typescript-styled-plugin": "^1.0.1", "@types/react": "^19.0.0", "@types/react-dom": "^19.0.0", diff --git a/src/common/ipc.ts b/src/common/ipc.ts index 00bc019..c9f367c 100644 --- a/src/common/ipc.ts +++ b/src/common/ipc.ts @@ -13,6 +13,7 @@ export enum IpcCall { HttpRequest = "http-request", GrpcRequest = "grpc-request", + GrpcReflection = "get-methods-grpc-reflection", AbortRequest = "abort-request", } diff --git a/src/common/proto.ts b/src/common/proto.ts index c74c1eb..dfc1930 100644 --- a/src/common/proto.ts +++ b/src/common/proto.ts @@ -3,7 +3,7 @@ type ProtoObject = ProtoMessageDescriptor | ProtoRepeated | ProtoOneOf | ProtoLi interface ProtoMessageDescriptor { type: "message"; name: string; - fields: Record; + fields: Record; } interface ProtoRepeated { @@ -13,7 +13,7 @@ interface ProtoRepeated { interface ProtoOneOf { type: "oneof"; - fields: Record; + fields: Record; } interface ProtoLiteral { @@ -31,3 +31,9 @@ interface ProtoEnum { name: string; values: Array<{ value: number; name: string }>; } + +interface ProtoField { + id: number; + name: string; + type: ProtoObject; +} diff --git a/src/common/request-types.ts b/src/common/request-types.ts index 5cb5284..84172ca 100644 --- a/src/common/request-types.ts +++ b/src/common/request-types.ts @@ -36,6 +36,7 @@ export interface GrpcRequestData { lastExecute: number; isExecuting: boolean; history: GrpcRequestData[]; + useReflection?: boolean; kind?: GrpcRequestKind; protoFile?: { protoPath: string; rootDir: string }; diff --git a/src/electron/Communication/grpc-reflection.ts b/src/electron/Communication/grpc-reflection.ts new file mode 100644 index 0000000..67288f8 --- /dev/null +++ b/src/electron/Communication/grpc-reflection.ts @@ -0,0 +1,333 @@ +import * as path from "node:path"; +import type * as grpc from "@grpc/grpc-js"; +import * as proto from "@grpc/proto-loader"; +import * as protobufjs from "protobufjs"; +import type * as protobuf_descriptor from "protobufjs/ext/descriptor"; +import type { ServerReflectionRequest } from "@grpc/reflection/build/src/generated/grpc/reflection/v1/ServerReflectionRequest"; +import type { ServerReflectionResponse } from "@grpc/reflection/build/src/generated/grpc/reflection/v1/ServerReflectionResponse"; +import type { FileDescriptorResponse } from "@grpc/reflection/build/src/generated/grpc/reflection/v1/FileDescriptorResponse"; +import type { ServiceResponse } from "@grpc/reflection/build/src/generated/grpc/reflection/v1/ServiceResponse"; +import type { MethodInfo, ProtoService } from "../../common/grpc"; + +let ReflectionService: proto.ServiceDefinition = undefined!; +let FileDescriptorProto: protobufjs.Type = undefined!; +async function loadPrerequisites(): Promise { + if (!ReflectionService) { + // Parse the reflection proto to call the service + const parsedProto = await proto.load("reflection.proto", { + includeDirs: ["node_modules/@grpc/reflection/build/proto/grpc/reflection/v1"], + }); + ReflectionService = parsedProto["grpc.reflection.v1.ServerReflection"] as proto.ServiceDefinition; + + // parse the descriptor proto to decode the reflection messages + const root = await protobufjs.load(path.join("node_modules/protobufjs", "google/protobuf/descriptor.proto")); + FileDescriptorProto = root.lookupType("google.protobuf.FileDescriptorProto"); + } +} + +const reflectionCache = new Map(); + +export class GrpcReflectionHandler { + private pendingRequests = 0; + private stream: grpc.ClientDuplexStream = undefined!; + private services: protobuf_descriptor.IServiceDescriptorProto[] = []; + private knownTypes = new Map(); + private knownEnums = new Map(); + + private constructor(private client: grpc.Client) { + this.client = client; + } + + static async fetchServicesFromServer(client: grpc.Client): Promise> { + const handler = new GrpcReflectionHandler(client); + await loadPrerequisites(); + reflectionCache.clear(); + return await handler.fetchServices(); + } + + static async getMethodInfo(client: grpc.Client, service: string, method: string): Promise { + const methodPath = methodKey(service, method); + let methodInfo = reflectionCache.get(methodPath); + if (methodInfo) return methodInfo; + + // Method not found in cache, re-execute the reflection call to make sure cache is up to date + await GrpcReflectionHandler.fetchServicesFromServer(client); + + methodInfo = reflectionCache.get(methodPath); + if (methodInfo === undefined) { + throw `Failed to find method info for ${methodPath} through reflection`; + } + return methodInfo; + } + + async fetchServices(): Promise> { + return new Promise((resolve) => { + const method = ReflectionService.ServerReflectionInfo; + this.stream = this.client.makeBidiStreamRequest( + method.path, + method.requestSerialize, + method.responseDeserialize, + ); + + // Request a list of all the services + this.requestServiceList(); + + this.stream.on("data", async (response: ServerReflectionResponse) => { + if (response?.listServicesResponse?.service) { + this.handleListServicesResponse(response.listServicesResponse.service); + } else if (response?.fileDescriptorResponse) { + this.handleFileDescriptorResponse(response.fileDescriptorResponse); + } + + this.pendingRequests--; + if (this.pendingRequests === 0) { + this.stream.end(); + } + }); + + this.stream.on("error", (err) => { + resolve({ success: false, error: err.message }); + }); + + this.stream.on("end", () => { + // Map the reflection info to the info we need + resolve({ success: true, value: this.mapServices() }); + }); + }); + } + + handleListServicesResponse(services: ServiceResponse[]): void { + // Request the type of all the received services + for (const service of services) { + this.requestType(service.name!); + } + } + + handleFileDescriptorResponse(response: FileDescriptorResponse): void { + const fileDescBytes = response.fileDescriptorProto?.[0] as Buffer; + // Use protobufjs to decode the descriptor bytes + const fileDesc = FileDescriptorProto.decode(fileDescBytes) as protobuf_descriptor.IFileDescriptorProto; + + // First read all message and types from this file so we don't re-request them later + if (fileDesc.enumType) { + for (const enumType of fileDesc.enumType) { + // Save both the local and globally specified name + this.knownEnums.set(enumType.name!, enumType); + this.knownEnums.set(`.${fileDesc.package}.${enumType.name}`, enumType); + } + } + if (fileDesc.messageType) { + for (const msg of fileDesc.messageType) { + // Save both the local name and globally specified name + this.knownTypes.set(msg.name!, msg); + this.knownTypes.set(`.${fileDesc.package}.${msg.name}`, msg); + if (msg.enumType) { + // Also save the nested enums + for (const enumType of msg.enumType) { + this.knownEnums.set(enumType.name!, enumType); + } + } + if (msg.nestedType) { + // Also save nested messages + for (const nestedType of msg.nestedType) { + this.knownTypes.set(nestedType.name!, nestedType); + } + } + } + } + // Then read all the services and see which types we need to request as follow-up + if (fileDesc?.service) { + for (const svc of fileDesc.service) { + svc.name = `${fileDesc.package}.${svc.name}`; + this.services.push(svc); + for (const method of svc.method ?? []) { + if (method.inputType && !this.knownTypes.has(method.inputType)) { + // Request the file containing this unknown type + this.requestType(method.inputType); + } + if (method.outputType && !this.knownTypes.has(method.outputType)) { + // Request the file containing this unknown type + this.requestType(method.outputType); + } + } + } + } + } + + requestServiceList(): void { + this.stream.write({ + listServices: "", + }); + this.pendingRequests++; + } + + requestType(typeName: string): void { + this.stream.write({ + fileContainingSymbol: typeName, + }); + this.pendingRequests++; + } + + mapServices(): ProtoService[] { + const services = this.services.map((svc) => ({ + name: svc.name!, + methods: (svc.method ?? []).map((method) => ({ + name: method.name!, + requestStream: method.clientStreaming ?? false, + serverStream: method.serverStreaming ?? false, + requestType: this.mapMessageType(method.inputType!), + responseType: this.mapMessageType(method.outputType!), + })), + })); + + for (const svc of services) { + for (const method of svc.methods) { + reflectionCache.set(methodKey(svc.name, method.name), method); + } + } + + return services; + } + + mapMessageType(typeName: string): ProtoMessageDescriptor { + const type = this.knownTypes.get(typeName); + if (!type) { + throw `unknown type ${typeName}`; + } + + const fields: ProtoMessageDescriptor["fields"] = {}; + const oneOfs = new Map(); + for (const field of type.field ?? []) { + let innerType = this.mapMessageField(field); + const inOneOf = Object.hasOwn(field, "oneofIndex"); + if (field.label === ProtoFieldLabel.LABEL_OPTIONAL && !inOneOf) { + innerType = { type: "optional", optionalType: innerType }; + } + if (field.label === ProtoFieldLabel.LABEL_REPEATED) { + innerType = { type: "repeated", repeatedType: innerType }; + } + + if (inOneOf && field.oneofIndex !== undefined) { + if (!oneOfs.has(field.oneofIndex)) { + oneOfs.set(field.oneofIndex, []); + } + oneOfs.get(field.oneofIndex)!.push({ + id: field.number!, + name: field.name!, + type: innerType, + }); + } else { + fields[field.name!] = { + id: field.number!, + name: field.name!, + type: innerType, + }; + } + } + + let oneOfId = 0; + for (const oneof of type.oneofDecl ?? []) { + const oneOfFields: ProtoOneOf["fields"] = {}; + const oneOfValues = oneOfs.get(oneOfId) ?? []; + + // If only one value, assume optional: + if (oneOfValues.length === 1) { + const value = oneOfValues[0]; + fields[value.name] = { + id: value.id, + name: value.name, + type: { + type: "optional", + optionalType: value.type, + }, + }; + } else { + for (const field of oneOfValues) { + oneOfFields[field.name] = field; + } + fields[oneof.name!] = { + id: -1, + name: oneof.name!, + type: { + type: "oneof", + fields: oneOfFields, + }, + }; + } + oneOfId++; + } + + return { + type: "message", + name: type.name!, + fields, + }; + } + + mapMessageField(field: protobuf_descriptor.IFieldDescriptorProto): ProtoObject { + if (field.type === ProtoFieldType.TYPE_MESSAGE) { + return this.mapMessageType(field.typeName!); + } + if (field.type === ProtoFieldType.TYPE_ENUM) { + const enumType = this.knownEnums.get(field.typeName!); + if (!enumType) { + throw `Unknown enum type ${field.typeName}`; + } + return { + type: "enum", + name: field.typeName!, + values: + enumType.value?.map((v) => ({ + name: v.name!, + value: 0, + })) ?? [], + }; + } + return { + type: "literal", + literalType: mapLiteralType(field.type!), + }; + } +} + +function methodKey(service: string, method: string): string { + return `${service}/${method}`; +} + +/** + * Enum values declared in protobufjs\ext\descriptor\index.js + */ +enum ProtoFieldLabel { + LABEL_OPTIONAL = 1, + LABEL_REQUIRED = 2, + LABEL_REPEATED = 3, +} +enum ProtoFieldType { + TYPE_DOUBLE = 1, + TYPE_FLOAT = 2, + TYPE_INT64 = 3, + TYPE_UINT64 = 4, + TYPE_INT32 = 5, + TYPE_FIXED64 = 6, + TYPE_FIXED32 = 7, + TYPE_BOOL = 8, + TYPE_STRING = 9, + TYPE_GROUP = 10, + TYPE_MESSAGE = 11, + TYPE_BYTES = 12, + TYPE_UINT32 = 13, + TYPE_ENUM = 14, + TYPE_SFIXED32 = 15, + TYPE_SFIXED64 = 16, + TYPE_SINT32 = 17, + TYPE_SINT64 = 18, +} + +function mapLiteralType(type: protobuf_descriptor.IFieldDescriptorProtoType): string { + switch (type) { + case ProtoFieldType.TYPE_STRING: + return "string"; + default: + return ProtoFieldType[type].split("_")[1].toLowerCase(); + } +} diff --git a/src/electron/Communication/grpc.ts b/src/electron/Communication/grpc.ts index 20aabbf..aba0fc6 100644 --- a/src/electron/Communication/grpc.ts +++ b/src/electron/Communication/grpc.ts @@ -2,54 +2,222 @@ import * as fs from "node:fs/promises"; import * as path from "node:path"; import * as grpc from "@grpc/grpc-js"; import * as proto from "@grpc/proto-loader"; +import * as protobuf from "protobufjs"; import { dialog } from "electron"; import JSON5 from "json5"; -import type { GrpcServerStreamDataEvent, GrpcServerStreamErrorEvent, GrpcStreamClosedEvent } from "../../common/grpc"; +import type { + GrpcServerStreamDataEvent, + GrpcServerStreamErrorEvent, + GrpcStreamClosedEvent, + ProtoService, +} from "../../common/grpc"; import { type BrowseProtoResult, IpcEvent } from "../../common/ipc"; import type { GrpcRequestData, GrpcResponse, GrpcServerStreamData, RequestId } from "../../common/request-types"; +import { GrpcReflectionHandler } from "./grpc-reflection"; const RequestCancelHandles: Partial void>> = {}; export async function makeGrpcRequest(request: GrpcRequestData, ipc: Electron.WebContents): Promise { - if (!request.protoFile || !request.rpc) { - throw "invalid request"; + if (!request.rpc) { + return { + result: "error", + code: "INVALID", + detail: "Invalid rpc data", + time: 0, + }; } const GenericClient = grpc.makeGenericClientConstructor({}, ""); const client = new GenericClient(request.url, grpc.credentials.createInsecure()); - const parsedProto = await parseProtoPackageDescription(request.protoFile.protoPath, request.protoFile.rootDir); - const service = parsedProto[request.rpc.service]; + if (request.protoFile) { + const parsedProto = await parseProtoPackageDescription(request.protoFile.protoPath, request.protoFile.rootDir); + const service = parsedProto[request.rpc.service]; + + if (service && isServiceDefinition(service)) { + const method = service[request.rpc.method]; + if (method) { + if (!method.requestStream && !method.responseStream) { + return await grpcUnaryRequest( + request, + method.path, + method.requestSerialize, + method.responseDeserialize, + client, + ); + } + + if (!method.requestStream && method.responseStream) { + return await grpcServerStreamingRequest( + request, + method.path, + method.requestSerialize, + method.responseDeserialize, + client, + ipc, + ); + } + + return { + result: "error", + code: "INVALID", + detail: "Request streaming is not (yet) supported", + time: 0, + }; + } + } + } else { + //reflection + const methodInfo = await GrpcReflectionHandler.getMethodInfo(client, request.rpc.service, request.rpc.method); + if (methodInfo.requestType === undefined || methodInfo.responseType === undefined) { + return { + result: "error", + code: "INVALID", + detail: "Reflected method information could not figure out request or response type", + time: 0, + }; + } - if (service && isServiceDefinition(service)) { - const method = service[request.rpc.method]; - if (method) { - if (!method.requestStream && !method.responseStream) { - return await grpcUnaryRequest(request, method, client); + const typeName = (message: ProtoObject): string => { + switch (message.type) { + case "message": + return message.name; + case "enum": + return message.name; + case "literal": + return message.literalType; + case "optional": + return typeName(message.optionalType); + case "repeated": + return typeName(message.repeatedType); + default: + throw `Don't know how to get name for ${message.type} proto object!`; } + }; + const messageDescriptorToProtoType = (message: ProtoMessageDescriptor): protobuf.Type => { + const type = new protobuf.Type(message.name); + + const handleField = (field: ProtoField) => { + if (field.type.type === "oneof") { + if (type.oneofs === undefined) type.oneofs = {}; + + const oneof = new protobuf.OneOf(field.name); + type.oneofs[field.name] = oneof; + type.oneofsArray.push(type.oneofs[field.name]); + + for (const [name, oneOfField] of Object.entries(field.type.fields)) { + handleField(oneOfField); + const f = type.fields[oneOfField.name]; + oneof.fieldsArray.push(f); + f.optional = true; + f.required = false; + f.partOf = oneof; + } + return; + } + + let innerType: ProtoObject | undefined = undefined; + if (field.type.type === "message") { + innerType = field.type; + } else if (field.type.type === "optional") { + innerType = field.type.optionalType; + } else if (field.type.type === "repeated") { + innerType = field.type.repeatedType; + } + + if (!innerType) { + throw `Could not determine type of proto field ${field.name}=${field.id}! ${JSON.stringify(field.type)}`; + } + + type.fields[field.name] = new protobuf.Field(field.name, field.id, typeName(innerType)); + + if (field.type.type === "optional") { + type.fields[field.name].optional = true; + } else if (field.type.type === "repeated") { + type.fields[field.name].repeated = true; + } else { + type.fields[field.name].optional = false; + type.fields[field.name].required = true; + } - if (!method.requestStream && method.responseStream) { - return await grpcServerStreamingRequest(request, method, client, ipc); + if (innerType.type === "message") { + // Fill in resolved types because we can't resolve in the reflected information + type.fields[field.name].resolved = true; + type.fields[field.name].resolvedType = messageDescriptorToProtoType(innerType); + } else if (innerType.type === "enum") { + type.fields[field.name].resolved = true; + const enumType = new protobuf.Enum(innerType.name); + for (const v of innerType.values) { + enumType.values[v.name] = v.value; + } + type.fields[field.name].resolvedType = enumType; + } + }; + + for (const field of Object.values(message.fields)) { + if (!field) continue; + handleField(field); } + return type; + }; + + const requestMessage = messageDescriptorToProtoType(methodInfo.requestType); + const responseMessage = messageDescriptorToProtoType(methodInfo.responseType); + const methodPath = `/${request.rpc.service}/${request.rpc.method}`; - return { result: "error", code: "INVALID", detail: "Request streaming is not (yet) supported", time: 0 }; + if (!methodInfo.requestStream && !methodInfo.serverStream) { + return await grpcUnaryRequest( + request, + methodPath, + (o) => Buffer.from(requestMessage.encode(o).finish()), + (b) => responseMessage.decode(b), + client, + ); } + + if (!methodInfo.requestStream && methodInfo.serverStream) { + return await grpcServerStreamingRequest( + request, + methodPath, + (o) => Buffer.from(requestMessage.encode(o).finish()), + (b) => responseMessage.decode(b), + client, + ipc, + ); + } + + return { + result: "error", + code: "INVALID", + detail: "Request streaming is not (yet) supported", + time: 0, + }; } return { result: "error", code: "INVALID", detail: "Invalid request", time: 0 }; } +export async function getMethodsViaReflection(serverUrl: string): Promise> { + const GenericClient = grpc.makeGenericClientConstructor({}, ""); + const client = new GenericClient(serverUrl, grpc.credentials.createInsecure()); + + // This protocol is so complicated it was moved to its own file + return GrpcReflectionHandler.fetchServicesFromServer(client); +} + function grpcUnaryRequest( request: GrpcRequestData, - method: proto.MethodDefinition, + path: string, + requestSerialize: (obj: object) => Buffer, + responseDeserialize: (b: Buffer) => object, client: grpc.Client, ): Promise { return new Promise((resolve) => { const start = performance.now(); const call = client.makeUnaryRequest( - method.path, - method.requestSerialize, - (r) => JSON.stringify(method.responseDeserialize(r), null, 2), + path, + requestSerialize, + (r) => JSON.stringify(responseDeserialize(r), null, 2), parseRequestBody(request.body), (err: grpc.ServiceError | null, value?: string) => { delete RequestCancelHandles[request.id]; @@ -81,14 +249,16 @@ function grpcUnaryRequest( async function grpcServerStreamingRequest( request: GrpcRequestData, - method: proto.MethodDefinition, + path: string, + requestSerialize: (obj: object) => Buffer, + responseDeserialize: (b: Buffer) => object, client: grpc.Client, ipc: Electron.WebContents, ): Promise { const stream = client.makeServerStreamRequest( - method.path, - method.requestSerialize, - method.responseDeserialize, + path, + requestSerialize, + responseDeserialize, parseRequestBody(request.body), ); @@ -176,7 +346,7 @@ export async function findProtoFiles(protoRoot: string): Promise { } function parseProtoPackageDescription(protoPath: string, protoRootDir: string): Promise { - return proto.load(path.join(protoRootDir, protoPath), { includeDirs: [protoRootDir] }); + return proto.load(protoPath, { includeDirs: [protoRootDir] }); } function isServiceDefinition(desc: proto.AnyDefinition): desc is proto.ServiceDefinition { diff --git a/src/electron/Communication/proto.ts b/src/electron/Communication/proto.ts index 1c4ae83..997e203 100644 --- a/src/electron/Communication/proto.ts +++ b/src/electron/Communication/proto.ts @@ -82,21 +82,37 @@ function mapType(type: protobufjs.Type | protobufjs.Enum): ProtoObject { } // Else: Assume message - const members: Record = {}; + const members: ProtoMessageDescriptor["fields"] = {}; for (const field of type.fieldsArray) { const fieldType = field.resolvedType ? mapType(field.resolvedType) : mapLiteralType(field.type); if (isOptionalField(field)) { - members[field.name] = fieldType ? { type: "optional", optionalType: fieldType } : undefined; + members[field.name] = fieldType + ? { + id: field.id, + name: field.name, + type: { type: "optional", optionalType: fieldType }, + } + : undefined; continue; } if (field.partOf instanceof protobufjs.OneOf) continue; // Skip oneofs for now if (field.repeated && fieldType) { - members[field.name] = { type: "repeated", repeatedType: fieldType }; + members[field.name] = { + id: field.id, + name: field.name, + type: { type: "repeated", repeatedType: fieldType }, + }; } else { - members[field.name] = fieldType; + members[field.name] = fieldType + ? { + id: field.id, + name: field.name, + type: fieldType, + } + : undefined; } } @@ -107,16 +123,20 @@ function mapType(type: protobufjs.Type | protobufjs.Enum): ProtoObject { continue; } - const oneOfMembers: Record = {}; + const oneOfMembers: ProtoOneOf["fields"] = {}; for (const field of oneof.fieldsArray) { const nestedFieldType = field.resolvedType ? mapType(field.resolvedType) : mapLiteralType(field.type); if (nestedFieldType) { - oneOfMembers[field.name] = nestedFieldType; + oneOfMembers[field.name] = { + id: field.id, + name: field.name, + type: nestedFieldType, + }; } } - members[oneof.name] = { type: "oneof", fields: oneOfMembers }; + members[oneof.name] = { id: 0, name: oneof.name, type: { type: "oneof", fields: oneOfMembers } }; } return { diff --git a/src/electron/Storage/persist.ts b/src/electron/Storage/persist.ts index a136274..400b530 100644 --- a/src/electron/Storage/persist.ts +++ b/src/electron/Storage/persist.ts @@ -110,18 +110,14 @@ function fixPersistedData( url: ri.url ?? "", kind: ri.kind, + useReflection: ri.useReflection ?? false, protoFile: ri.protoFile ? { protoPath: ri.protoFile?.protoPath ?? "", rootDir: ri.protoFile?.rootDir ?? "", } : undefined, - rpc: ri.rpc - ? { - service: ri.rpc?.service ?? "", - method: ri.rpc?.method ?? "", - } - : undefined, + rpc: ri.rpc ? fixGrpcRpc(ri.rpc) : undefined, body: ri.body ?? "{}", @@ -131,6 +127,17 @@ function fixPersistedData( }; } + function fixGrpcRpc(m: DeepPartial): GrpcRequestData["rpc"] { + if (m?.service === undefined) return undefined; + if (m?.method === undefined) return undefined; + if (typeof m.method !== "string") return undefined; + + return { + service: m.service, + method: m.method, + }; + } + function fixGroup(ri: DeepPartial): RequestGroup { return { type: "group", diff --git a/src/electron/electron-app.ts b/src/electron/electron-app.ts index 9b17f34..ef8c6cd 100644 --- a/src/electron/electron-app.ts +++ b/src/electron/electron-app.ts @@ -3,16 +3,15 @@ import { BrowserWindow, app, ipcMain, nativeTheme } from "electron"; import type { ProtoContent, ProtoRoot } from "../common/grpc"; import { type BrowseProtoResult, IpcCall, IpcEvent } from "../common/ipc"; import type { PersistedState } from "../common/persist-state"; -import type { - GrpcRequestData, - GrpcResponse, - HttpRequestData, - HttpResponseData, - RequestId, - RequestList, -} from "../common/request-types"; +import type { GrpcRequestData, GrpcResponse, HttpRequestData, RequestId, RequestList } from "../common/request-types"; import { backgroundColor } from "../renderer/src/palette"; -import { browseProtoRoot, cancelGrpcRequest, findProtoFiles, makeGrpcRequest } from "./Communication/grpc"; +import { + browseProtoRoot, + cancelGrpcRequest, + findProtoFiles, + getMethodsViaReflection, + makeGrpcRequest, +} from "./Communication/grpc"; import { cancelHttpRequest, makeHttpRequest } from "./Communication/http"; import { parseProtoFile } from "./Communication/proto"; import { exportDirectory, importDirectory } from "./Storage/import-export"; @@ -66,6 +65,14 @@ app.whenReady().then(async () => { return { result: "error", code: "EXCEPTION", detail: err.toString(), time: 0 }; } }); + ipcMain.handle(IpcCall.GrpcReflection, async (_, url: string) => { + try { + return await getMethodsViaReflection(url); + // biome-ignore lint/suspicious/noExplicitAny: + } catch (err: any) { + return { result: "error", code: "EXCEPTION", detail: err.toString(), time: 0 }; + } + }); ipcMain.handle(IpcCall.AbortRequest, (_, requestType: "http" | "grpc", requestId: RequestId) => { if (requestType === "http") { diff --git a/src/renderer/src/App.tsx b/src/renderer/src/App.tsx index 52493bb..bb9e59e 100644 --- a/src/renderer/src/App.tsx +++ b/src/renderer/src/App.tsx @@ -156,7 +156,11 @@ const AppContainer = observer(({ context }: { context: AppContext }) => { ) : ( {context.activeRequest.type === "grpc" && ( - + )} {context.activeRequest.type === "http" && } diff --git a/src/renderer/src/GrpcRequestPanel.tsx b/src/renderer/src/GrpcRequestPanel.tsx index 61c3aab..507c176 100644 --- a/src/renderer/src/GrpcRequestPanel.tsx +++ b/src/renderer/src/GrpcRequestPanel.tsx @@ -5,14 +5,16 @@ import { runInAction } from "mobx"; import { observer } from "mobx-react-lite"; import { useCallback, useEffect, useState } from "react"; import styled from "styled-components"; -import type { MethodInfo, ProtoContent } from "../../common/grpc"; +import type { MethodInfo, ProtoContent, ProtoService } from "../../common/grpc"; import { IpcCall } from "../../common/ipc"; import { type GrpcRequestData, GrpcRequestKind } from "../../common/request-types"; -import type { ProtoConfig } from "./AppContext"; +import type { AppContext, ProtoConfig } from "./AppContext"; import { type SelectProtoModalResult, SelectProtosModal } from "./modals/select-protos"; import { backgroundHoverColor, errorColor } from "./palette"; -import { debounce } from "./util/debounce"; import { defaultProtoBody, lintProtoJson } from "./util/proto-lint"; +import { substituteVariables } from "./util/substitute-variables"; +import Toggle from "./common-components/toggle"; +import { RefreshCcw } from "lucide-react"; const RequestPanelRoot = styled.div` display: flex; @@ -55,6 +57,8 @@ const GrpcMethodPopoverRoot = styled.div` right: anchor(right); top: anchor(bottom); width: auto; + border: none; + max-height: 70vh; `; const GrpcMethodPopoverEntry = styled.button` @@ -62,6 +66,8 @@ const GrpcMethodPopoverEntry = styled.button` border: unset; width: 100%; + margin-top: 1px; + cursor: pointer; &:hover { @@ -74,6 +80,36 @@ const ProtoErrorBox = styled.div` padding: 5px 10px; `; +const ReflectionHeader = styled.label` + display: flex; + align-items: center; + gap: 8px; +`; + +const UseReflectionText = styled.label` + font-size: 13px; + margin-top: -3px; +`; + +const ReflectedServicesText = styled.label` + font-size: 13px; + margin-top: -3px; + margin-left: auto; + color: #ccc; +`; + +const ReflectionButton = styled.button` + padding: 6px; + cursor: pointer; + border: none; + border-radius: 5px; + background-color: var(--color-background); + + &:hover { + background-color: ${backgroundHoverColor}; + } +`; + interface MethodDescriptor { service: string; method: MethodInfo; @@ -89,7 +125,11 @@ const codemirrorTheme = EditorView.theme({ }); export const GrpcRequestPanel = observer( - ({ activeRequest, protoConfig }: { activeRequest: GrpcRequestData; protoConfig: ProtoConfig }) => { + ({ + context, + activeRequest, + protoConfig, + }: { context: AppContext; activeRequest: GrpcRequestData; protoConfig: ProtoConfig }) => { const [protoModalOpen, setProtoModalOpen] = useState(false); const openProtoModal = useCallback(() => setProtoModalOpen(true), []); @@ -109,6 +149,7 @@ export const GrpcRequestPanel = observer( ); const [rpcs, setRpcs] = useState(undefined); + const [reflectedServices, setReflectedServices] = useState(undefined); const [protoError, setError] = useState(undefined); useEffect(() => { @@ -144,38 +185,74 @@ export const GrpcRequestPanel = observer( const selectMethod = useCallback( (method: MethodDescriptor) => { - activeRequest.rpc = { - service: method.service, - method: method.method.name, - }; - if (rpcs) { - const rpc = rpcs.find( - (rpc) => - rpc.service === activeRequest.rpc?.service && rpc.method.name === activeRequest.rpc.method, - ); - if (rpc?.method) { - if (!rpc.method.requestStream && !rpc.method.serverStream) { + runInAction(() => { + activeRequest.rpc = { + service: method.service, + method: method.method.name, + }; + if (activeRequest.useReflection && reflectedServices) { + const service = reflectedServices.find((service) => service.name === method.service); + if (!service) return; + const rpc = service.methods.find((m) => m.name === method.method.name); + if (!rpc) return; + + if (!rpc.requestStream && !rpc.serverStream) { activeRequest.kind = GrpcRequestKind.Unary; - } else if (!rpc.method.requestStream && rpc.method.serverStream) { + } else if (!rpc.requestStream && rpc.serverStream) { activeRequest.kind = GrpcRequestKind.ResponseStreaming; - } else if (rpc.method.requestStream && !rpc.method.serverStream) { + } else if (rpc.requestStream && !rpc.serverStream) { activeRequest.kind = GrpcRequestKind.RequestStreaming; - } else if (rpc.method.requestStream && rpc.method.serverStream) { + } else if (rpc.requestStream && rpc.serverStream) { activeRequest.kind = GrpcRequestKind.Bidirectional; } + + if (rpc?.requestType) { + activeRequest.body = defaultProtoBody(rpc.requestType).value; + } + } else if (rpcs) { + const rpc = rpcs.find( + (rpc) => + rpc.service === activeRequest.rpc?.service && + rpc.method.name === activeRequest.rpc.method, + ); + + if (rpc?.method) { + if (!rpc.method.requestStream && !rpc.method.serverStream) { + activeRequest.kind = GrpcRequestKind.Unary; + } else if (!rpc.method.requestStream && rpc.method.serverStream) { + activeRequest.kind = GrpcRequestKind.ResponseStreaming; + } else if (rpc.method.requestStream && !rpc.method.serverStream) { + activeRequest.kind = GrpcRequestKind.RequestStreaming; + } else if (rpc.method.requestStream && rpc.method.serverStream) { + activeRequest.kind = GrpcRequestKind.Bidirectional; + } + } + if (rpc?.method.requestType) { + activeRequest.body = defaultProtoBody(rpc.method.requestType).value; + } } - if (rpc?.method.requestType) { - activeRequest.body = defaultProtoBody(rpc.method.requestType).value; - } - } + }); }, - [activeRequest, rpcs], + [activeRequest, rpcs, reflectedServices], ); const activeRpc = rpcs?.find( (rpc) => rpc.service === activeRequest.rpc?.service && rpc.method.name === activeRequest.rpc.method, ); + const reflectedRpcs = () => { + const reflectionMethods: MethodDescriptor[] = []; + for (const service of reflectedServices ?? []) { + for (const method of service.methods) { + reflectionMethods.push({ + service: service.name, + method, + }); + } + } + return reflectionMethods; + }; + const linter = CodeMirrorLint.linter((view) => { if (!activeRpc?.method.requestType) return []; @@ -183,16 +260,66 @@ export const GrpcRequestPanel = observer( return lintProtoJson(content, activeRpc.method.requestType); }); + const fetchReflectionMethods = useCallback(async () => { + try { + const url = substituteVariables(activeRequest.url, context.substitutionVariables); + const result: Result = await window.electron.ipcRenderer.invoke( + IpcCall.GrpcReflection, + url, + ); + if (result.success) { + setReflectedServices(result.value); + setError(undefined); + } else { + setRpcs([]); + setError(`Error fetching reflection methods: ${result.error}`); + } + } catch (err) { + setRpcs([]); + setError(`Failed to fetch reflection methods: ${err}`); + } + }, [activeRequest.url, context.substitutionVariables]); + return ( - - - {activeRequest.protoFile - ? shortProtoPath(activeRequest.protoFile.protoPath) - : "Select proto file..."} - - + + + + runInAction(() => { + activeRequest.useReflection = v; + }) + } + /> + Use reflection + {activeRequest.useReflection && ( + <> + + {reflectedServices?.length ?? 0} services found + + + + + + )} + + + {!activeRequest.useReflection && ( + + {activeRequest.protoFile + ? shortProtoPath(activeRequest.protoFile.protoPath) + : "Select proto file..."} + + )} + {activeRequest.rpc ? activeRequest.rpc.method : "Select method..."} {protoError && {protoError}} @@ -221,13 +348,15 @@ function GrpcMethodPopover({ }: { rpcs: MethodDescriptor[]; onSelectMethod: (method: MethodDescriptor) => void }) { return ( - {rpcs.map((r, i) => ( - onSelectMethod(r)} - >{`${r.service} / ${r.method.name}`} - ))} + {rpcs.length === 0 + ? "No requests found..." + : rpcs.map((r, i) => ( + onSelectMethod(r)} + >{`${r.service} / ${r.method.name}`} + ))} ); } diff --git a/src/renderer/src/common-components/toggle.tsx b/src/renderer/src/common-components/toggle.tsx new file mode 100644 index 0000000..0d2909d --- /dev/null +++ b/src/renderer/src/common-components/toggle.tsx @@ -0,0 +1,72 @@ +import styled from "styled-components"; + +const Toggle = ({ checked, onChange }: { checked: boolean; onChange: (newValue: boolean) => void }) => { + return ( + + + + ); +}; + +const StyledWrapper = styled.div` + .switch { + --secondary-container: #3a4b39; + --primary: #84da89; + font-size: 11px; + position: relative; + display: inline-block; + width: 3.7em; + height: 1.8em; + } + + .switch input { + display: none; + opacity: 0; + width: 0; + height: 0; + } + + .slider { + position: absolute; + cursor: pointer; + top: 0; + left: 0; + right: 0; + bottom: 0; + background-color: #313033; + transition: .2s; + border-radius: 30px; + } + + .slider:before { + position: absolute; + content: ""; + height: 1.4em; + width: 1.4em; + border-radius: 20px; + left: 0.2em; + bottom: 0.2em; + background-color: #aeaaae; + transition: .4s; + } + + input:checked + .slider::before { + background-color: var(--primary); + } + + input:checked + .slider { + background-color: var(--secondary-container); + } + + input:focus + .slider { + box-shadow: 0 0 1px var(--secondary-container); + } + + input:checked + .slider:before { + transform: translateX(1.9em); + }`; + +export default Toggle; diff --git a/src/renderer/src/util/proto-lint.ts b/src/renderer/src/util/proto-lint.ts index 1a5f916..7d729f2 100644 --- a/src/renderer/src/util/proto-lint.ts +++ b/src/renderer/src/util/proto-lint.ts @@ -50,17 +50,23 @@ function lintMessage( diagnostics: CodeMirrorLint.Diagnostic[], ): void { const protoFields = Object.entries(protoDescriptor.fields); - const knownFields = new Map(protoFields); + const knownFields = new Map(); + for (const [name, field] of protoFields) { + knownFields.set(name, field?.type); + } const requiredFields = new Set( protoFields - .filter(([name, field]) => field?.type !== "optional" && field?.type !== "oneof") + .filter(([name, field]) => field?.type.type !== "optional" && field?.type.type !== "oneof") .map(([name, field]) => name), ); + const seenFields = new Map(); - const oneofs = protoFields.filter(([name, field]) => field?.type === "oneof") as Array<[string, ProtoOneOf]>; - for (const [_, oneof] of oneofs) { - for (const [name, type] of Object.entries(oneof.fields)) { - knownFields.set(name, type); + for (const [_, field] of protoFields) { + if (field === undefined) continue; + if (field.type.type !== "oneof") continue; + + for (const [name, type] of Object.entries(field.type.fields)) { + knownFields.set(name, type.type); } } @@ -93,7 +99,9 @@ function lintMessage( } // Check oneofs - for (const [oneOfName, oneof] of oneofs) { + for (const [oneOfName, field] of protoFields) { + if (field === undefined || field.type.type !== "oneof") continue; + const oneof = field.type; if (oneof) { // Count how many fields were seen const seen = []; @@ -226,15 +234,15 @@ export function defaultProtoBody(protoDescriptor: ProtoObject, indent = ""): { v let result = "{\n"; for (const [name, field] of Object.entries(protoDescriptor.fields)) { if (field) { - if (field.type === "oneof") { - const options = Object.entries(field.fields); + if (field.type.type === "oneof") { + const options = Object.entries(field.type.fields); if (options.length === 0) continue; - const { value, comments } = defaultProtoBody(field, indent + INDENT_STEP); + const { value, comments } = defaultProtoBody(field.type, indent + INDENT_STEP); const comment = comments && comments.length > 0 ? ` // ${comments.join(", ")}` : ""; result += `${indent}${INDENT_STEP}${options[0][0]}: ${value},${comment}\n`; } else { - const { value, comments } = defaultProtoBody(field, indent + INDENT_STEP); + const { value, comments } = defaultProtoBody(field.type, indent + INDENT_STEP); const comment = comments && comments.length > 0 ? ` // ${comments.join(", ")}` : ""; result += `${indent}${INDENT_STEP}${name}: ${value},${comment}\n`; } @@ -272,7 +280,7 @@ export function defaultProtoBody(protoDescriptor: ProtoObject, indent = ""): { v const [optionName, optionType] = options[0]; const comment = `oneof ${options.map(([name]) => name).join(", ")}`; - const { value, comments } = defaultProtoBody(optionType, indent + INDENT_STEP); + const { value, comments } = defaultProtoBody(optionType.type, indent + INDENT_STEP); return { value, comments: comments ? [comment, ...comments] : [comment] }; } return { value: "{}", comments: ["Empty oneof"] }; diff --git a/test/grpc-test-server.ts b/test/grpc-test-server.ts index 8461f00..af32ab1 100644 --- a/test/grpc-test-server.ts +++ b/test/grpc-test-server.ts @@ -1,6 +1,7 @@ // If you want you can run this test server with `npx tsx test/grpc-test-server.ts` import * as grpc from "@grpc/grpc-js"; +import * as grpc_reflection from "@grpc/reflection"; import * as proto from "@grpc/proto-loader"; import * as protobufjs from "protobufjs"; @@ -41,6 +42,11 @@ interface MessageWithBools { myoptionalbool?: boolean; } +interface MessageWithEnums { + globalEnum: number; + nestedEnum: number; +} + const server = new grpc.Server(); server.addService(greeterService.service, { SayHello: (call: grpc.ServerUnaryCall, callback: grpc.sendUnaryData) => { @@ -69,8 +75,17 @@ server.addService(greeterService.service, { } setTimeout(send, 3000); }, - TestNested: (call: grpc.ServerUnaryCall, callback: grpc.sendUnaryData) => { - callback(null, { message: JSON.stringify(call.request) }); + TestEnums: ( + call: grpc.ServerUnaryCall, + callback: grpc.sendUnaryData, + ) => { + callback(null, call.request); + }, + TestNested: ( + call: grpc.ServerUnaryCall, + callback: grpc.sendUnaryData, + ) => { + callback(null, call.request); }, TestGetStringList: ( call: grpc.ServerUnaryCall, @@ -107,6 +122,9 @@ server.addService(greeterService.service, { }, }); +const reflection = new grpc_reflection.ReflectionService(protoPackage); +reflection.addToServer(server); + // Start server, will block process server.bindAsync(SERVER_ADDRESS, grpc.ServerCredentials.createInsecure(), () => { console.log(`Server started at ${SERVER_ADDRESS}`); diff --git a/test/proto.spec.ts b/test/proto.spec.ts index 7ca3e25..d56263a 100644 --- a/test/proto.spec.ts +++ b/test/proto.spec.ts @@ -40,7 +40,11 @@ describe("myproto.proto", async () => { type: "message", name: "HelloRequest", fields: { - name: { type: "literal", literalType: "string" }, + name: { + id: 1, + name: "name", + type: { type: "literal", literalType: "string" }, + }, }, }); }); @@ -54,7 +58,11 @@ describe("myproto.proto", async () => { type: "message", name: "HelloReply", fields: { - message: { type: "literal", literalType: "string" }, + message: { + id: 1, + name: "message", + type: { type: "literal", literalType: "string" }, + }, }, }); }); @@ -67,7 +75,11 @@ describe("myproto.proto", async () => { type: "message", name: "HelloRequest", fields: { - name: { type: "literal", literalType: "string" }, + name: { + id: 1, + name: "name", + type: { type: "literal", literalType: "string" }, + }, }, }; @@ -75,7 +87,11 @@ describe("myproto.proto", async () => { type: "message", name: "HelloReply", fields: { - message: { type: "literal", literalType: "string" }, + message: { + id: 1, + name: "message", + type: { type: "literal", literalType: "string" }, + }, }, }; @@ -83,16 +99,40 @@ describe("myproto.proto", async () => { type: "message", name: "NestedRequest", fields: { - reply: helloReplyType, - request2: { type: "optional", optionalType: helloRequestType }, + reply: { + id: 1, + name: "reply", + type: helloReplyType, + }, + request2: { + id: 2, + name: "request2", + type: { type: "optional", optionalType: helloRequestType }, + }, testoneof: { - type: "oneof", - fields: { - reply3: helloReplyType, - request4: helloRequestType, + id: 0, + name: "testoneof", + type: { + type: "oneof", + fields: { + reply3: { + id: 3, + name: "reply3", + type: helloReplyType, + }, + request4: { + id: 4, + name: "request4", + type: helloRequestType, + }, + }, }, }, - replies: { type: "repeated", repeatedType: helloReplyType }, + replies: { + id: 5, + name: "replies", + type: { type: "repeated", repeatedType: helloReplyType }, + }, }, }); }); @@ -107,26 +147,34 @@ describe("myproto.proto", async () => { name: "MessageWithEnums", fields: { globalEnum: { - type: "enum", - name: "GlobalEnum", - values: [ - { name: "A", value: 0 }, - { name: "B", value: 1 }, - { name: "C", value: 2 }, - ], - }, - nestedEnum: { - type: "optional", - optionalType: { + id: 0, + name: "globalEnum", + type: { type: "enum", - name: "NestedEnum", + name: "GlobalEnum", values: [ - { name: "X", value: 0 }, - { name: "Y", value: 1 }, - { name: "Z", value: 2 }, + { name: "A", value: 0 }, + { name: "B", value: 1 }, + { name: "C", value: 2 }, ], }, }, + nestedEnum: { + id: 1, + name: "nestedEnum", + type: { + type: "optional", + optionalType: { + type: "enum", + name: "NestedEnum", + values: [ + { name: "X", value: 0 }, + { name: "Y", value: 1 }, + { name: "Z", value: 2 }, + ], + }, + }, + }, }, }); }); @@ -150,10 +198,18 @@ describe("sibling1/nestedproto1.proto", async () => { name: "MessageUsingParent", fields: { helloRequest: { - type: "message", - name: "HelloRequest", - fields: { - name: { type: "literal", literalType: "string" }, + id: 1, + name: "helloRequest", + type: { + type: "message", + name: "HelloRequest", + fields: { + name: { + id: 1, + name: "name", + type: { type: "literal", literalType: "string" }, + }, + }, }, }, }, @@ -163,17 +219,29 @@ describe("sibling1/nestedproto1.proto", async () => { name: "MessageUsingSibling", fields: { sibling: { - type: "message", - name: "SiblingMessage", - fields: { - oneofthese: { - type: "oneof", - fields: { - b: { type: "literal", literalType: "bool" }, - i: { type: "literal", literalType: "int32" }, + id: 1, + name: "sibling", + type: { + type: "message", + name: "SiblingMessage", + fields: { + oneofthese: { + id: 0, + name: "oneofthese", + type: { + type: "oneof", + fields: { + b: { id: 1, name: "b", type: { type: "literal", literalType: "bool" } }, + i: { id: 2, name: "i", type: { type: "literal", literalType: "int32" } }, + }, + }, + }, + doubles: { + id: 3, + name: "doubles", + type: { type: "repeated", repeatedType: { type: "literal", literalType: "double" } }, }, }, - doubles: { type: "repeated", repeatedType: { type: "literal", literalType: "double" } }, }, }, }, diff --git a/test/protos/myproto.proto b/test/protos/myproto.proto index 5075e3c..efa0c1e 100644 --- a/test/protos/myproto.proto +++ b/test/protos/myproto.proto @@ -12,11 +12,11 @@ service Greeter { rpc StreamHello(HelloRequest) returns (stream HelloReply); rpc BiDirectionalStream(stream HelloRequest) returns (stream HelloReply); - rpc TestNested(NestedRequest) returns (HelloReply); + rpc TestNested(NestedRequest) returns (NestedRequest); rpc TestGetStringList(HelloRequest) returns (StringListReply); - rpc TestEnums(MessageWithEnums) returns (HelloReply); + rpc TestEnums(MessageWithEnums) returns (MessageWithEnums); rpc ErrorWithTrailers(HelloRequest) returns (HelloReply);