diff --git a/pkg/api/pagination.go b/pkg/api/pagination.go new file mode 100644 index 00000000..7afa3e1e --- /dev/null +++ b/pkg/api/pagination.go @@ -0,0 +1,72 @@ +package api + +import ( + "fmt" + "strconv" + + "github.com/docker/cagent/pkg/session" +) + +type PaginationParams struct { + Limit int + Before string +} + +const DefaultLimit = 50 + +const MaxLimit = 200 + +func PaginateMessages(messages []session.Message, params PaginationParams) ([]session.Message, *PaginationMetadata, error) { + totalCount := len(messages) + + limit := params.Limit + if limit <= 0 { + limit = DefaultLimit + } + if limit > MaxLimit { + limit = MaxLimit + } + + var beforeIndex int + var err error + + if params.Before != "" { + beforeIndex, err = strconv.Atoi(params.Before) + if err != nil { + return nil, nil, fmt.Errorf("invalid before cursor: %w", err) + } + } + + startIdx := 0 + var endIdx int + + if params.Before != "" { + endIdx = beforeIndex + if endIdx <= 0 { + return []session.Message{}, &PaginationMetadata{ + TotalMessages: totalCount, + Limit: 0, + }, nil + } + actualStart := max(endIdx-limit, startIdx) + startIdx = actualStart + } else { + actualStart := max(totalCount-limit, 0) + startIdx = actualStart + endIdx = totalCount + } + + paginatedMessages := messages[startIdx:endIdx] + + metadata := &PaginationMetadata{ + TotalMessages: totalCount, + Limit: len(paginatedMessages), + } + + // Only set cursor if there are more (older) messages available + if len(paginatedMessages) > 0 && startIdx > 0 { + metadata.PrevCursor = strconv.Itoa(startIdx) + } + + return paginatedMessages, metadata, nil +} diff --git a/pkg/api/pagination_test.go b/pkg/api/pagination_test.go new file mode 100644 index 00000000..c52b9b20 --- /dev/null +++ b/pkg/api/pagination_test.go @@ -0,0 +1,207 @@ +package api + +import ( + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/chat" + "github.com/docker/cagent/pkg/session" +) + +func createTestMessages(count int) []session.Message { + messages := make([]session.Message, count) + for i := range count { + role := chat.MessageRoleUser + if i%2 == 1 { + role = chat.MessageRoleAssistant + } + messages[i] = session.Message{ + AgentFilename: "test.yaml", + AgentName: "test", + Message: chat.Message{ + Role: role, + Content: "Message " + strconv.Itoa(i), + CreatedAt: time.Now().Add(time.Duration(i) * time.Second).Format(time.RFC3339), + }, + } + } + return messages +} + +func TestPaginateMessages_FirstPage(t *testing.T) { + messages := createTestMessages(100) + + params := PaginationParams{ + Limit: 10, + } + + paginated, meta, err := PaginateMessages(messages, params) + require.NoError(t, err) + assert.Len(t, paginated, 10) + assert.Equal(t, 100, meta.TotalMessages) + assert.Equal(t, 10, meta.Limit) + assert.NotEmpty(t, meta.PrevCursor) // More older messages available + + // Should get most recent 10 messages (for chat infinite scroll) + // For 100 messages, indices 90-99 should be returned + // Check that we got recent messages by verifying they're different from the old first messages + assert.NotEqual(t, "Message 0", paginated[0].Message.Content) // Not the oldest message + assert.NotEqual(t, "Message 9", paginated[9].Message.Content) // Not the 10th oldest message + assert.Equal(t, "Message 90", paginated[0].Message.Content) // Index 90 + assert.Equal(t, "Message 99", paginated[9].Message.Content) // Index 99 +} + +func TestPaginateMessages_WithBeforeCursorPagination(t *testing.T) { + messages := createTestMessages(20) // Use smaller dataset for easier debugging + + // Start with a page at the end (messages 10-19) + endPageParams := PaginationParams{ + Limit: 10, + Before: "20", // Get 10 messages before index 20 (which should give us 10-19) + } + endPage, endMeta, err := PaginateMessages(messages, endPageParams) + require.NoError(t, err) + + // Verify we got the end page + assert.Len(t, endPage, 10) + assert.Equal(t, "Message 10", endPage[0].Message.Content) // Index 10 + assert.Equal(t, "Message 19", endPage[9].Message.Content) // Index 19 + + // Get previous page using before cursor (should give us messages 0-9) + prevPageParams := PaginationParams{ + Limit: 10, + Before: endMeta.PrevCursor, // Before the end page + } + prevPage, prevMeta, err := PaginateMessages(messages, prevPageParams) + require.NoError(t, err) + + assert.Len(t, prevPage, 10) + assert.Empty(t, prevMeta.PrevCursor) // No more older messages + + // Should get messages 0-9 + assert.Equal(t, "Message 0", prevPage[0].Message.Content) // Index 0 + assert.Equal(t, "Message 9", prevPage[9].Message.Content) // Index 9 + + // No overlap between pages + assert.NotEqual(t, endPage[0].Message.Content, prevPage[9].Message.Content) +} + +func TestPaginateMessages_WithBeforeCursor(t *testing.T) { + messages := createTestMessages(100) + + // Get a page in the middle (starting at index 50) + middleCursor := strconv.Itoa(50) + + params := PaginationParams{ + Limit: 10, + Before: middleCursor, + } + + paginated, meta, err := PaginateMessages(messages, params) + require.NoError(t, err) + + assert.Len(t, paginated, 10) + assert.NotEmpty(t, meta.PrevCursor) // There are older messages + + // Should get 10 messages before index 50 (indices 40-49) + assert.Equal(t, "Message "+strconv.Itoa(40), paginated[0].Message.Content) + assert.Equal(t, "Message "+strconv.Itoa(49), paginated[9].Message.Content) +} + +func TestPaginateMessages_DefaultLimit(t *testing.T) { + messages := createTestMessages(100) + + params := PaginationParams{ + Limit: 0, // Should use default + } + + paginated, meta, err := PaginateMessages(messages, params) + require.NoError(t, err) + + assert.Len(t, paginated, DefaultLimit) + assert.Equal(t, DefaultLimit, meta.Limit) +} + +func TestPaginateMessages_MaxLimit(t *testing.T) { + messages := createTestMessages(300) + + params := PaginationParams{ + Limit: 500, // Should be capped at MaxLimit + } + + paginated, meta, err := PaginateMessages(messages, params) + require.NoError(t, err) + + assert.Len(t, paginated, MaxLimit) + assert.Equal(t, MaxLimit, meta.Limit) +} + +func TestPaginateMessages_EmptyMessages(t *testing.T) { + messages := []session.Message{} + + params := PaginationParams{ + Limit: 10, + } + + paginated, meta, err := PaginateMessages(messages, params) + require.NoError(t, err) + + assert.Empty(t, paginated) + assert.Equal(t, 0, meta.TotalMessages) + assert.Empty(t, meta.PrevCursor) // No messages at all +} + +func TestPaginateMessages_LastPage(t *testing.T) { + messages := createTestMessages(25) + + // Get the oldest 5 messages (using before cursor to limit to earliest messages) + lastPageParams := PaginationParams{ + Limit: 10, + Before: "5", // Before the 6th message (index 5) + } + lastPage, lastMeta, err := PaginateMessages(messages, lastPageParams) + require.NoError(t, err) + + assert.Len(t, lastPage, 5) // Only 5 messages (0-4) + assert.Empty(t, lastMeta.PrevCursor) // No more older messages + assert.Equal(t, 25, lastMeta.TotalMessages) + + // Should get the first 5 messages + assert.Equal(t, "Message 0", lastPage[0].Message.Content) + assert.Equal(t, "Message 4", lastPage[4].Message.Content) +} + +func TestPaginateMessages_BeforeFirstMessage(t *testing.T) { + messages := createTestMessages(10) + + // Create cursor pointing to before first message + firstCursor := strconv.Itoa(0) + + params := PaginationParams{ + Limit: 10, + Before: firstCursor, + } + + paginated, meta, err := PaginateMessages(messages, params) + require.NoError(t, err) + + assert.Empty(t, paginated) + assert.Empty(t, meta.PrevCursor) // No messages at all +} + +func TestPaginateMessages_InvalidCursor(t *testing.T) { + messages := createTestMessages(10) + + params := PaginationParams{ + Limit: 10, + Before: "invalid-cursor", + } + + _, _, err := PaginateMessages(messages, params) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid before cursor") +} diff --git a/pkg/api/types.go b/pkg/api/types.go index 2e2882de..897035a1 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -126,14 +126,22 @@ type SessionsResponse struct { // SessionResponse represents a detailed session type SessionResponse struct { - ID string `json:"id"` - Title string `json:"title"` - Messages []session.Message `json:"messages,omitempty"` - CreatedAt time.Time `json:"created_at"` - ToolsApproved bool `json:"tools_approved"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - WorkingDir string `json:"working_dir,omitempty"` + ID string `json:"id"` + Title string `json:"title"` + Messages []session.Message `json:"messages,omitempty"` + CreatedAt time.Time `json:"created_at"` + ToolsApproved bool `json:"tools_approved"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + WorkingDir string `json:"working_dir,omitempty"` + Pagination *PaginationMetadata `json:"pagination,omitempty"` +} + +// PaginationMetadata contains pagination information +type PaginationMetadata struct { + TotalMessages int `json:"total_messages"` // Total number of messages in session + Limit int `json:"limit"` // Number of messages in this response + PrevCursor string `json:"prev_cursor,omitempty"` // Cursor for previous page (empty if no more messages) } // ResumeSessionRequest represents a request to resume a session diff --git a/pkg/server/server.go b/pkg/server/server.go index 6bbeffe8..c5e56ca9 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -13,6 +13,7 @@ import ( "os" "path/filepath" "sort" + "strconv" "strings" "sync" "time" @@ -931,15 +932,35 @@ func (s *Server) getSession(c echo.Context) error { return echo.NewHTTPError(http.StatusNotFound, "session not found") } + params := api.PaginationParams{ + Limit: api.DefaultLimit, + Before: c.QueryParam("before"), + } + + if limitStr := c.QueryParam("limit"); limitStr != "" { + if limit, err := strconv.Atoi(limitStr); err == nil && limit > 0 { + params.Limit = limit + } + } + + allMessages := sess.GetAllMessages() + + paginatedMessages, pagination, err := api.PaginateMessages(allMessages, params) + if err != nil { + slog.Error("Failed to paginate messages", "error", err) + return echo.NewHTTPError(http.StatusBadRequest, "invalid pagination parameters: "+err.Error()) + } + sr := api.SessionResponse{ ID: sess.ID, Title: sess.Title, CreatedAt: sess.CreatedAt, - Messages: sess.GetAllMessages(), + Messages: paginatedMessages, ToolsApproved: sess.ToolsApproved, InputTokens: sess.InputTokens, OutputTokens: sess.OutputTokens, WorkingDir: sess.WorkingDir, + Pagination: pagination, } return c.JSON(http.StatusOK, sr)