@@ -203,7 +203,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
203203}
204204
205205// UpdatePullRequest creates a tool to update an existing pull request.
206- func UpdatePullRequest (getClient GetClientFn , t translations.TranslationHelperFunc ) (mcp.Tool , server.ToolHandlerFunc ) {
206+ func UpdatePullRequest (getClient GetClientFn , getGQLClient GetGQLClientFn , t translations.TranslationHelperFunc ) (mcp.Tool , server.ToolHandlerFunc ) {
207207 return mcp .NewTool ("update_pull_request" ,
208208 mcp .WithDescription (t ("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION" , "Update an existing pull request in a GitHub repository." )),
209209 mcp .WithToolAnnotation (mcp.ToolAnnotation {
@@ -232,6 +232,9 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
232232 mcp .Description ("New state" ),
233233 mcp .Enum ("open" , "closed" ),
234234 ),
235+ mcp .WithBoolean ("draft" ,
236+ mcp .Description ("Mark pull request as draft (true) or ready for review (false)" ),
237+ ),
235238 mcp .WithString ("base" ,
236239 mcp .Description ("New base branch name" ),
237240 ),
@@ -253,74 +256,165 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
253256 return mcp .NewToolResultError (err .Error ()), nil
254257 }
255258
256- // Build the update struct only with provided fields
259+ draftProvided := request .GetArguments ()["draft" ] != nil
260+ var draftValue bool
261+ if draftProvided {
262+ draftValue , err = OptionalParam [bool ](request , "draft" )
263+ if err != nil {
264+ return nil , err
265+ }
266+ }
267+
257268 update := & github.PullRequest {}
258- updateNeeded := false
269+ restUpdateNeeded := false
259270
260271 if title , ok , err := OptionalParamOK [string ](request , "title" ); err != nil {
261272 return mcp .NewToolResultError (err .Error ()), nil
262273 } else if ok {
263274 update .Title = github .Ptr (title )
264- updateNeeded = true
275+ restUpdateNeeded = true
265276 }
266277
267278 if body , ok , err := OptionalParamOK [string ](request , "body" ); err != nil {
268279 return mcp .NewToolResultError (err .Error ()), nil
269280 } else if ok {
270281 update .Body = github .Ptr (body )
271- updateNeeded = true
282+ restUpdateNeeded = true
272283 }
273284
274285 if state , ok , err := OptionalParamOK [string ](request , "state" ); err != nil {
275286 return mcp .NewToolResultError (err .Error ()), nil
276287 } else if ok {
277288 update .State = github .Ptr (state )
278- updateNeeded = true
289+ restUpdateNeeded = true
279290 }
280291
281292 if base , ok , err := OptionalParamOK [string ](request , "base" ); err != nil {
282293 return mcp .NewToolResultError (err .Error ()), nil
283294 } else if ok {
284295 update .Base = & github.PullRequestBranch {Ref : github .Ptr (base )}
285- updateNeeded = true
296+ restUpdateNeeded = true
286297 }
287298
288299 if maintainerCanModify , ok , err := OptionalParamOK [bool ](request , "maintainer_can_modify" ); err != nil {
289300 return mcp .NewToolResultError (err .Error ()), nil
290301 } else if ok {
291302 update .MaintainerCanModify = github .Ptr (maintainerCanModify )
292- updateNeeded = true
303+ restUpdateNeeded = true
293304 }
294305
295- if ! updateNeeded {
306+ if ! restUpdateNeeded && ! draftProvided {
296307 return mcp .NewToolResultError ("No update parameters provided." ), nil
297308 }
298309
310+ if restUpdateNeeded {
311+ client , err := getClient (ctx )
312+ if err != nil {
313+ return nil , fmt .Errorf ("failed to get GitHub client: %w" , err )
314+ }
315+
316+ _ , resp , err := client .PullRequests .Edit (ctx , owner , repo , pullNumber , update )
317+ if err != nil {
318+ return ghErrors .NewGitHubAPIErrorResponse (ctx ,
319+ "failed to update pull request" ,
320+ resp ,
321+ err ,
322+ ), nil
323+ }
324+ defer func () { _ = resp .Body .Close () }()
325+
326+ if resp .StatusCode != http .StatusOK {
327+ body , err := io .ReadAll (resp .Body )
328+ if err != nil {
329+ return nil , fmt .Errorf ("failed to read response body: %w" , err )
330+ }
331+ return mcp .NewToolResultError (fmt .Sprintf ("failed to update pull request: %s" , string (body ))), nil
332+ }
333+ }
334+
335+ if draftProvided {
336+ gqlClient , err := getGQLClient (ctx )
337+ if err != nil {
338+ return nil , fmt .Errorf ("failed to get GitHub GraphQL client: %w" , err )
339+ }
340+
341+ var prQuery struct {
342+ Repository struct {
343+ PullRequest struct {
344+ ID githubv4.ID
345+ IsDraft githubv4.Boolean
346+ } `graphql:"pullRequest(number: $prNum)"`
347+ } `graphql:"repository(owner: $owner, name: $repo)"`
348+ }
349+
350+ err = gqlClient .Query (ctx , & prQuery , map [string ]interface {}{
351+ "owner" : githubv4 .String (owner ),
352+ "repo" : githubv4 .String (repo ),
353+ "prNum" : githubv4 .Int (pullNumber ), // #nosec G115 - pull request numbers are always small positive integers
354+ })
355+ if err != nil {
356+ return ghErrors .NewGitHubGraphQLErrorResponse (ctx , "Failed to find pull request" , err ), nil
357+ }
358+
359+ currentIsDraft := bool (prQuery .Repository .PullRequest .IsDraft )
360+
361+ if currentIsDraft != draftValue {
362+ if draftValue {
363+ // Convert to draft
364+ var mutation struct {
365+ ConvertPullRequestToDraft struct {
366+ PullRequest struct {
367+ ID githubv4.ID
368+ IsDraft githubv4.Boolean
369+ }
370+ } `graphql:"convertPullRequestToDraft(input: $input)"`
371+ }
372+
373+ err = gqlClient .Mutate (ctx , & mutation , githubv4.ConvertPullRequestToDraftInput {
374+ PullRequestID : prQuery .Repository .PullRequest .ID ,
375+ }, nil )
376+ if err != nil {
377+ return ghErrors .NewGitHubGraphQLErrorResponse (ctx , "Failed to convert pull request to draft" , err ), nil
378+ }
379+ } else {
380+ // Mark as ready for review
381+ var mutation struct {
382+ MarkPullRequestReadyForReview struct {
383+ PullRequest struct {
384+ ID githubv4.ID
385+ IsDraft githubv4.Boolean
386+ }
387+ } `graphql:"markPullRequestReadyForReview(input: $input)"`
388+ }
389+
390+ err = gqlClient .Mutate (ctx , & mutation , githubv4.MarkPullRequestReadyForReviewInput {
391+ PullRequestID : prQuery .Repository .PullRequest .ID ,
392+ }, nil )
393+ if err != nil {
394+ return ghErrors .NewGitHubGraphQLErrorResponse (ctx , "Failed to mark pull request ready for review" , err ), nil
395+ }
396+ }
397+ }
398+ }
399+
299400 client , err := getClient (ctx )
300401 if err != nil {
301- return nil , fmt . Errorf ( "failed to get GitHub client: %w" , err )
402+ return nil , err
302403 }
303- pr , resp , err := client .PullRequests .Edit (ctx , owner , repo , pullNumber , update )
404+
405+ finalPR , resp , err := client .PullRequests .Get (ctx , owner , repo , pullNumber )
304406 if err != nil {
305- return ghErrors .NewGitHubAPIErrorResponse (ctx ,
306- "failed to update pull request" ,
307- resp ,
308- err ,
309- ), nil
407+ return ghErrors .NewGitHubAPIErrorResponse (ctx , "Failed to get pull request" , resp , err ), nil
310408 }
311- defer func () { _ = resp .Body .Close () }()
312-
313- if resp .StatusCode != http .StatusOK {
314- body , err := io .ReadAll (resp .Body )
315- if err != nil {
316- return nil , fmt .Errorf ("failed to read response body: %w" , err )
409+ defer func () {
410+ if resp != nil && resp .Body != nil {
411+ _ = resp .Body .Close ()
317412 }
318- return mcp .NewToolResultError (fmt .Sprintf ("failed to update pull request: %s" , string (body ))), nil
319- }
413+ }()
320414
321- r , err := json .Marshal (pr )
415+ r , err := json .Marshal (finalPR )
322416 if err != nil {
323- return nil , fmt .Errorf ( "failed to marshal response: %w " , err )
417+ return mcp . NewToolResultError ( fmt .Sprintf ( "Failed to marshal response: %v " , err )), nil
324418 }
325419
326420 return mcp .NewToolResultText (string (r )), nil
0 commit comments