From e1f7cff16164c06636f9639b07e351c777fa264a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 18 Oct 2025 17:57:52 +0000 Subject: [PATCH 01/10] Initial plan From e2f9479b10e714f35c91ad5d46cc32c97f57f04b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 18 Oct 2025 18:08:57 +0000 Subject: [PATCH 02/10] Add POST request tests for HTTP client - Add TestHTTPClientPOSTWithTextPayload to test POST with plain text - Add TestHTTPClientPOSTWithJSONPayload to test POST with JSON data - Add TestHTTPClientPOSTWithFormData to test POST with form-encoded data - All tests verify request body is captured correctly in WARC files - Add missing imports (strconv, net/url) for new test functions Co-authored-by: CorentinB <5089772+CorentinB@users.noreply.github.com> --- client_test.go | 371 ++++++++++++++++++++++++++++++++++++++++++++++ dedupe.go | 4 +- gzip_interface.go | 2 +- main_test.go | 2 +- write.go | 2 +- 5 files changed, 376 insertions(+), 5 deletions(-) diff --git a/client_test.go b/client_test.go index e0e1d3c..669c989 100644 --- a/client_test.go +++ b/client_test.go @@ -13,9 +13,11 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "os" "path" "path/filepath" + "strconv" "strings" "sync" "testing" @@ -1785,6 +1787,375 @@ func TestHTTPClientWithIPv6Disabled(t *testing.T) { } } +func TestHTTPClientPOSTWithTextPayload(t *testing.T) { + var ( + rotatorSettings = defaultRotatorSettings(t) + err error + ) + + // Create a test server that expects POST requests and echoes back the received body + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Failed to read request body: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("Received: ")) + w.Write(body) + })) + defer server.Close() + + // Initialize the WARC-writing HTTP client + httpClient, err := NewWARCWritingHTTPClient(HTTPClientSettings{RotatorSettings: rotatorSettings}) + if err != nil { + t.Fatalf("Unable to init WARC writing HTTP client: %s", err) + } + waitForErrors := drainErrChan(t, httpClient.ErrChan) + + // Create a POST request with a text payload + requestBody := strings.NewReader("Hello from POST request") + req, err := http.NewRequest("POST", server.URL, requestBody) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "text/plain") + + resp, err := httpClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + io.Copy(io.Discard, resp.Body) + + httpClient.Close() + waitForErrors() + + files, err := filepath.Glob(rotatorSettings.OutputDirectory + "/*") + if err != nil { + t.Fatal(err) + } + + // Verify the WARC file was created + if len(files) == 0 { + t.Fatal("No WARC files were created") + } + + // Check the WARC records contain the POST request and response + for _, path := range files { + testFileHash(t, path) + + file, err := os.Open(path) + if err != nil { + t.Fatalf("failed to open %q: %v", path, err) + } + defer file.Close() + + reader, err := NewReader(file) + if err != nil { + t.Fatalf("warc.NewReader failed for %q: %v", path, err) + } + + foundRequest := false + foundResponse := false + + for { + 
record, err := reader.ReadRecord() + if err != nil { + if err == io.EOF { + break + } + t.Fatalf("warc.ReadRecord failed: %v", err) + } + + // Check for request record + if record.Header.Get("WARC-Type") == "request" { + foundRequest = true + record.Content.Seek(0, 0) + content, _ := io.ReadAll(record.Content) + contentStr := string(content) + + // Verify it's a POST request + if !strings.Contains(contentStr, "POST") { + t.Errorf("Request record does not contain POST method") + } + + // Verify the request body is present + if !strings.Contains(contentStr, "Hello from POST request") { + t.Errorf("Request record does not contain the expected request body") + } + } + + // Check for response record + if record.Header.Get("WARC-Type") == "response" { + foundResponse = true + } + + record.Content.Close() + } + + if !foundRequest { + t.Error("No request record found in WARC file") + } + if !foundResponse { + t.Error("No response record found in WARC file") + } + } +} + +func TestHTTPClientPOSTWithJSONPayload(t *testing.T) { + var ( + rotatorSettings = defaultRotatorSettings(t) + err error + ) + + // Create a test server that expects POST requests with JSON + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type: application/json, got %s", r.Header.Get("Content-Type")) + } + + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Failed to read request body: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"status":"success","received":`)) + w.Write(body) + w.Write([]byte(`}`)) + })) + defer server.Close() + + // Initialize the WARC-writing HTTP client + httpClient, err := NewWARCWritingHTTPClient(HTTPClientSettings{RotatorSettings: rotatorSettings}) + if err != nil { + t.Fatalf("Unable to init WARC writing HTTP client: %s", err) + } + waitForErrors := drainErrChan(t, httpClient.ErrChan) + + // Create a POST request with a JSON payload + jsonPayload := `{"name":"test","value":123}` + requestBody := strings.NewReader(jsonPayload) + req, err := http.NewRequest("POST", server.URL, requestBody) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := httpClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + io.Copy(io.Discard, resp.Body) + + httpClient.Close() + waitForErrors() + + files, err := filepath.Glob(rotatorSettings.OutputDirectory + "/*") + if err != nil { + t.Fatal(err) + } + + // Verify the WARC file was created + if len(files) == 0 { + t.Fatal("No WARC files were created") + } + + // Check the WARC records contain the POST request with JSON body + for _, path := range files { + testFileHash(t, path) + + file, err := os.Open(path) + if err != nil { + t.Fatalf("failed to open %q: %v", path, err) + } + defer file.Close() + + reader, err := NewReader(file) + if err != nil { + t.Fatalf("warc.NewReader failed for %q: %v", path, err) + } + + foundJSONRequest := false + + for { + record, err := reader.ReadRecord() + if err != nil { + if err == io.EOF { + break + } + t.Fatalf("warc.ReadRecord failed: %v", err) + } + + // Check for request record + if record.Header.Get("WARC-Type") == "request" { + 
record.Content.Seek(0, 0) + content, _ := io.ReadAll(record.Content) + contentStr := string(content) + + // Verify it's a POST request + if !strings.Contains(contentStr, "POST") { + t.Errorf("Request record does not contain POST method") + } + + // Verify the JSON payload is present + if strings.Contains(contentStr, jsonPayload) { + foundJSONRequest = true + } + } + + record.Content.Close() + } + + if !foundJSONRequest { + t.Error("JSON payload not found in request record") + } + } +} + +func TestHTTPClientPOSTWithFormData(t *testing.T) { + var ( + rotatorSettings = defaultRotatorSettings(t) + err error + ) + + // Create a test server that expects POST requests with form data + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + err := r.ParseForm() + if err != nil { + t.Errorf("Failed to parse form: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + username := r.FormValue("username") + password := r.FormValue("password") + + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("Login attempt for user: " + username + " (password length: " + strconv.Itoa(len(password)) + ")")) + })) + defer server.Close() + + // Initialize the WARC-writing HTTP client + httpClient, err := NewWARCWritingHTTPClient(HTTPClientSettings{RotatorSettings: rotatorSettings}) + if err != nil { + t.Fatalf("Unable to init WARC writing HTTP client: %s", err) + } + waitForErrors := drainErrChan(t, httpClient.ErrChan) + + // Create a POST request with form data + formData := url.Values{} + formData.Set("username", "testuser") + formData.Set("password", "testpass123") + + req, err := http.NewRequest("POST", server.URL, strings.NewReader(formData.Encode())) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := httpClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + io.Copy(io.Discard, resp.Body) + + httpClient.Close() + waitForErrors() + + files, err := filepath.Glob(rotatorSettings.OutputDirectory + "/*") + if err != nil { + t.Fatal(err) + } + + // Verify the WARC file was created + if len(files) == 0 { + t.Fatal("No WARC files were created") + } + + // Check the WARC records contain the POST request with form data + for _, path := range files { + testFileHash(t, path) + + file, err := os.Open(path) + if err != nil { + t.Fatalf("failed to open %q: %v", path, err) + } + defer file.Close() + + reader, err := NewReader(file) + if err != nil { + t.Fatalf("warc.NewReader failed for %q: %v", path, err) + } + + foundFormRequest := false + + for { + record, err := reader.ReadRecord() + if err != nil { + if err == io.EOF { + break + } + t.Fatalf("warc.ReadRecord failed: %v", err) + } + + // Check for request record + if record.Header.Get("WARC-Type") == "request" { + record.Content.Seek(0, 0) + content, _ := io.ReadAll(record.Content) + contentStr := string(content) + + // Verify it's a POST request + if !strings.Contains(contentStr, "POST") { + t.Errorf("Request record does not contain POST method") + } + + // Verify the form data is present (URL-encoded) + if strings.Contains(contentStr, "username=testuser") && strings.Contains(contentStr, "password=testpass123") { + foundFormRequest = true + } + } + + record.Content.Close() + } + + if !foundFormRequest { + t.Error("Form data not found in request 
record") + } + } +} + // MARK: Benchmarks func BenchmarkConcurrentUnder2MB(b *testing.B) { var ( diff --git a/dedupe.go b/dedupe.go index ba1b36b..c93cb94 100644 --- a/dedupe.go +++ b/dedupe.go @@ -51,7 +51,7 @@ func (d *customDialer) checkLocalRevisit(digest string) revisitRecord { func checkCDXRevisit(CDXURL string, digest string, targetURI string, cookie string) (revisitRecord, error) { // CDX expects no hash header. For now we need to strip it. digest = strings.SplitN(digest, ":", 2)[1] - + req, err := http.NewRequest("GET", CDXURL+"/web/timemap/cdx?url="+url.QueryEscape(targetURI)+"&limit=-1", nil) if err != nil { return revisitRecord{}, err @@ -95,7 +95,7 @@ func checkCDXRevisit(CDXURL string, digest string, targetURI string, cookie stri func checkDoppelgangerRevisit(DoppelgangerHost string, digest string, targetURI string) (revisitRecord, error) { // Doppelganger is not expecting a hash header either but this will all be rewritten ... shortly... digest = strings.SplitN(digest, ":", 2)[1] - + req, err := http.NewRequest("GET", DoppelgangerHost+"/api/records/"+digest+"?uri="+targetURI, nil) if err != nil { return revisitRecord{}, err diff --git a/gzip_interface.go b/gzip_interface.go index 4ca2fa7..84ceb0f 100644 --- a/gzip_interface.go +++ b/gzip_interface.go @@ -19,4 +19,4 @@ type GzipReaderInterface interface { io.ReadCloser Multistream(enable bool) Reset(r io.Reader) error -} \ No newline at end of file +} diff --git a/main_test.go b/main_test.go index 750793c..4031b0d 100644 --- a/main_test.go +++ b/main_test.go @@ -8,5 +8,5 @@ import ( // Verify leaks in ALL package tests. func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) + goleak.VerifyTestMain(m) } diff --git a/write.go b/write.go index a173856..8073672 100644 --- a/write.go +++ b/write.go @@ -8,8 +8,8 @@ import ( "strings" "time" - "github.com/internetarchive/gowarc/pkg/spooledtempfile" "github.com/google/uuid" + "github.com/internetarchive/gowarc/pkg/spooledtempfile" "github.com/klauspost/compress/zstd" ) From 884c796bf6c5e37cb378bb0493d86c6c2fb84025 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 31 Oct 2025 20:52:13 -0400 Subject: [PATCH 03/10] build(deps): bump the go-modules group with 2 updates (#153) Bumps the go-modules group with 2 updates: [github.com/klauspost/compress](https://github.com/klauspost/compress) and [github.com/refraction-networking/utls](https://github.com/refraction-networking/utls). Updates `github.com/klauspost/compress` from 1.18.0 to 1.18.1 - [Release notes](https://github.com/klauspost/compress/releases) - [Changelog](https://github.com/klauspost/compress/blob/master/.goreleaser.yml) - [Commits](https://github.com/klauspost/compress/compare/v1.18.0...v1.18.1) Updates `github.com/refraction-networking/utls` from 1.8.0 to 1.8.1 - [Release notes](https://github.com/refraction-networking/utls/releases) - [Commits](https://github.com/refraction-networking/utls/compare/v1.8.0...v1.8.1) --- updated-dependencies: - dependency-name: github.com/klauspost/compress dependency-version: 1.18.1 dependency-type: direct:production update-type: version-update:semver-patch dependency-group: go-modules - dependency-name: github.com/refraction-networking/utls dependency-version: 1.8.1 dependency-type: direct:production update-type: version-update:semver-patch dependency-group: go-modules ... 
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 go.mod | 4 ++--
 go.sum | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/go.mod b/go.mod
index 3f3e97b..6c45986 100644
--- a/go.mod
+++ b/go.mod
@@ -4,10 +4,10 @@ go 1.24.2
 
 require (
 	github.com/google/uuid v1.6.0
-	github.com/klauspost/compress v1.18.0
+	github.com/klauspost/compress v1.18.1
 	github.com/maypok86/otter v1.2.4
 	github.com/miekg/dns v1.1.68
-	github.com/refraction-networking/utls v1.8.0
+	github.com/refraction-networking/utls v1.8.1
 	github.com/remeh/sizedwaitgroup v1.0.0
 	github.com/spf13/cobra v1.10.1
 	github.com/things-go/go-socks5 v0.1.0
diff --git a/go.sum b/go.sum
index c728fca..d481126 100644
--- a/go.sum
+++ b/go.sum
@@ -13,8 +13,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
-github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
-github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
+github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
+github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
 github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE=
 github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
 github.com/maypok86/otter v1.2.4 h1:HhW1Pq6VdJkmWwcZZq19BlEQkHtI8xgsQzBVXJU0nfc=
@@ -23,8 +23,8 @@ github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
 github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/refraction-networking/utls v1.8.0 h1:L38krhiTAyj9EeiQQa2sg+hYb4qwLCqdMcpZrRfbONE=
-github.com/refraction-networking/utls v1.8.0/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
+github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
+github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
 github.com/remeh/sizedwaitgroup v1.0.0 h1:VNGGFwNo/R5+MJBf6yrsr110p0m4/OX4S3DCy7Kyl5E=
 github.com/remeh/sizedwaitgroup v1.0.0/go.mod h1:3j2R4OIe/SeS6YDhICBy22RWjJC5eNCJ1V+9+NVNYlo=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=

From 7e3810ef2bab5abcffc88a45d979a0a4990db0e1 Mon Sep 17 00:00:00 2001
From: yzqzss <30341059+yzqzss@users.noreply.github.com>
Date: Mon, 3 Nov 2025 18:22:04 +0800
Subject: [PATCH 04/10] fix: the value of `warcVer` changed unexpectedly (#154)

The lifecycle of warcVer spans multiple `readUntilDelim()` calls, and changes in the underlying buffer (intermediateBuf) data can cause unexpected changes to warcVer.
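To make the aliasing hazard concrete, here is a minimal, self-contained sketch (illustrative only; it does not reproduce the library's actual readUntilDelim or its intermediateBuf): a byte slice returned from a reused read buffer silently changes once the buffer is refilled, while a string copy of the same bytes does not.

package main

import "fmt"

func main() {
	// A reusable read buffer, standing in for the reader's internal intermediateBuf.
	buf := make([]byte, 32)

	// First "read": the WARC version line lands in the buffer.
	n := copy(buf, "WARC/1.1\r\n")
	line := buf[:n-2] // a sub-slice aliasing buf, like a value returned by readUntilDelim

	aliased := line        // still points into buf
	cloned := string(line) // copies the bytes out of buf, which is what the fix below does

	// A later read reuses the same buffer for different data.
	copy(buf, "WARC-Type: response\r\n")

	fmt.Printf("aliased: %q\n", aliased) // prints "WARC-Typ": silently corrupted
	fmt.Printf("cloned:  %q\n", cloned)  // prints "WARC/1.1": unaffected
}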
--- read.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/read.go b/read.go index 478ae9c..6e179d0 100644 --- a/read.go +++ b/read.go @@ -254,7 +254,8 @@ func (r *Reader) ReadRecord(opts ...ReadOpts) (*Record, error) { } } - warcVer, _, err := readUntilDelim(r.bufReader, []byte("\r\n")) + _warcVer, _, err := readUntilDelim(r.bufReader, []byte("\r\n")) + warcVer := string(_warcVer) // clone to avoid changes in underlying buffer if err != nil { if err == io.EOF && len(warcVer) == 0 { // treat as EOF for safety if member present but empty From cadb7b62773f5348b5f0654338ac350284fe67bb Mon Sep 17 00:00:00 2001 From: Corentin Barreau Date: Tue, 4 Nov 2025 21:31:15 +0100 Subject: [PATCH 05/10] Add platform-specific memory management for macOS to spooledtempfile (#150) * Add platform-specific memory management for macOS to spooledtempfile This commit adds macOS support to the spooledtempfile package's memory management functionality, which previously only supported Linux. Changes: - Split memory.go into platform-specific implementations using Go build tags - memory_linux.go: Linux implementation (cgroups v1/v2, /proc/meminfo) - memory_darwin.go: macOS implementation (sysctl VM statistics) - memory.go: Shared caching logic for all platforms - Added comprehensive platform-specific tests - memory_linux_test.go: Tests for cgroup and /proc/meminfo - memory_darwin_test.go: Tests for sysctl memory retrieval - memory_test.go: Platform-agnostic tests - Updated CI/CD workflow (.github/workflows/go.yml) - Added matrix strategy to test both ubuntu-latest and macos-latest - Platform-specific test verification for both Linux and macOS - Cross-compilation checks to ensure compatibility - Made spooled tests deterministic with memory mocking - Added mockMemoryUsage() helper for DRY test code - Fixed 8 tests that were failing on high-memory systems - Added 3 new threshold boundary tests (49%, 50%, 51%) - All tests now pass regardless of host system memory state The macOS implementation uses sysctl to query VM statistics and calculates memory usage as: (total - free - purgeable - speculative) / total All tests pass on both platforms with proper mocking ensuring deterministic behavior independent of actual system memory usage. * Fix test failure from cache state pollution between test packages Add ResetMemoryCache() function and call it in test cleanup to prevent global memoryUsageCache state from persisting across test packages. When running 'go test ./...', tests from pkg/spooledtempfile that mock memory usage (mocking high memory values like 60%) would pollute the global cache, causing TestHTTPClientRequestFailing to fail when run after them. The fix adds: 1. ResetMemoryCache() - clears the global cache (lastChecked and lastFraction) 2. Calls ResetMemoryCache() in mockMemoryUsage() cleanup to ensure clean state This maintains the performance benefits of the global cache while ensuring test isolation through explicit cleanup using t.Cleanup(). * Fix PR #150 review feedback: uint64 underflow, hardcoded paths, and endianness Address three issues identified in code review: 1. **Fix potential uint64 underflow** (memory_darwin.go): - If reclaimablePages > totalPages, subtraction would wrap to large number - Now clamps usedPages to 0 when reclaimable pages exceed total - Adds explicit bounds checking to prevent invalid memory fractions 2. 
**Fix hardcoded error path** (memory_linux.go): - Changed error message from literal "/proc/meminfo" to use procMeminfoPath variable - Improves diagnostic accuracy when paths are overridden in tests 3. **Simplify endianness handling** (memory_darwin.go): - Removed custom getSysctlUint32() helper function - Now uses unix.SysctlUint32() directly which handles endianness correctly - Removed unnecessary encoding/binary import - Updated tests to use unix.SysctlUint32() directly All tests pass. Cross-compilation verified for both Linux and macOS. * Fix testdata path resolution in mend tests to work from any directory Replace hardcoded relative paths with dynamic path resolution using runtime.Caller() to compute the testdata directory relative to the test file location. This fixes tests being skipped in CI/CD when running 'go test ./...' from the root directory. Tests now correctly find and run against testdata files. Changes: - Added getTestdataDir() helper function using runtime.Caller() - Replaced 5 hardcoded "../../testdata/warcs" paths with getTestdataDir() - Added runtime package import - Tests now run from any working directory (root, CI/CD, etc.) All TestAnalyzeWARCFile subtests now run and pass instead of being skipped. * Fix TestMendFunctionDirect path resolution to work in CI/CD Replace os.Getwd() + relative path construction with getTestdataDir() helper to make TestMendFunctionDirect work from any directory. This fixes 4 skipped subtests: - good.warc.gz.open - empty.warc.gz.open - corrupted-trailing-bytes.warc.gz.open - corrupted-mid-record.warc.gz.open These tests now run properly in CI/CD when 'go test ./...' is run from root. * Fix TestHTTPClient byte range tolerance for macOS compatibility Widen the acceptable byte range in TestHTTPClient from 27130-27160 to 27130-27170 to accommodate platform-specific differences in HTTP response headers between macOS and Linux. The test was failing on macOS with 27163 bytes due to platform-specific header variations, which is normal behavior across different operating systems. * Fix WARC version parsing and mend expectations * Extend CI job timeout to 15 minutes * Make HTTP client tests deterministic * Address copilot feedback on PR #150 - Remove unused vm.page_pageable_internal/external_count sysctls that cause failures on some macOS versions - Fix data races in memory_test.go by using ResetMemoryCache() and proper mutex locking - Fix cache pointer reassignment in spooled_test.go to prevent race conditions - Update CI test filter to reference only existing tests (TestMemoryFractionConsistency instead of TestGetSysctlUint32) * Limit macOS CI to spooledtempfile tests only Run only spooledtempfile package tests on macOS runners instead of the full test suite, since the macOS-specific code changes are limited to that package. This addresses review feedback to optimize CI runtime while maintaining platform-specific test coverage. * Add WARC format regression smoke test Introduce TestSmokeWARCFormatRegression to validate WARC format consistency using a frozen reference file (testdata/test.warc.gz). This test checks exact byte counts, record counts, Content-Length values, and digest hashes against known-good values. This complements the existing dynamic tests by providing explicit validation that the WARC format hasn't changed, addressing the concern about byte-level format regression detection while keeping the main integration tests maintainable. 
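As a rough worked example of the memory formula described above, using made-up numbers rather than real sysctl output: on a hypothetical machine with 16 GiB of RAM and 16 KiB pages there are 1,048,576 pages; if 200,000 are free, 50,000 purgeable, and 30,000 speculative, then 768,576 pages count as used and the fraction is about 0.733. The sketch below mirrors that arithmetic, including the underflow clamp mentioned in the review feedback; the constants are assumptions, not values taken from any particular system.

package main

import "fmt"

func main() {
	// Hypothetical inputs standing in for hw.memsize, vm.pagesize and the vm.page_*_count sysctls.
	const (
		totalBytes  = uint64(16) << 30 // 16 GiB of RAM
		pageSize    = uint64(16384)    // 16 KiB pages
		free        = uint64(200_000)  // vm.page_free_count
		purgeable   = uint64(50_000)   // vm.page_purgeable_count
		speculative = uint64(30_000)   // vm.page_speculative_count
	)

	totalPages := totalBytes / pageSize           // 1,048,576 pages
	reclaimable := free + purgeable + speculative // 280,000 pages

	// Clamp to zero so the uint64 subtraction cannot underflow.
	var usedPages uint64
	if reclaimable < totalPages {
		usedPages = totalPages - reclaimable
	}

	fraction := float64(usedPages*pageSize) / float64(totalBytes)
	fmt.Printf("used fraction: %.3f\n", fraction) // ~0.733
}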
--- .github/workflows/go.yml | 41 ++- client_test.go | 88 ++++- cmd/warc/mend/mend_test.go | 37 ++- pkg/spooledtempfile/memory.go | 206 ++---------- pkg/spooledtempfile/memory_darwin.go | 77 +++++ pkg/spooledtempfile/memory_darwin_test.go | 91 ++++++ pkg/spooledtempfile/memory_linux.go | 197 +++++++++++ pkg/spooledtempfile/memory_linux_test.go | 364 +++++++++++++++++++++ pkg/spooledtempfile/memory_test.go | 377 +++------------------- pkg/spooledtempfile/spooled.go | 33 -- pkg/spooledtempfile/spooled_test.go | 116 ++++++- read.go | 6 +- smoke_test.go | 154 +++++++++ 13 files changed, 1205 insertions(+), 582 deletions(-) create mode 100644 pkg/spooledtempfile/memory_darwin.go create mode 100644 pkg/spooledtempfile/memory_darwin_test.go create mode 100644 pkg/spooledtempfile/memory_linux.go create mode 100644 pkg/spooledtempfile/memory_linux_test.go create mode 100644 smoke_test.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index c029d61..f50d34b 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -11,8 +11,11 @@ permissions: jobs: build: - runs-on: ubuntu-latest - timeout-minutes: 5 + runs-on: ${{ matrix.os }} + timeout-minutes: 15 + strategy: + matrix: + os: [ubuntu-latest, macos-latest] steps: - uses: actions/checkout@v5 @@ -25,12 +28,42 @@ jobs: run: go build -v ./... - name: Goroutine leak detector + if: matrix.os == 'ubuntu-latest' continue-on-error: true run: go test -c -o tests && for test in $(go test -list . | grep -E "^(Test|Example)"); do ./tests -test.run "^$test\$" &>/dev/null && echo -e "$test passed\n" || echo -e "$test failed\n"; done - - name: Test + - name: Test (Full Suite) + if: matrix.os == 'ubuntu-latest' run: go test -race -v ./... + - name: Test (spooledtempfile only) + if: matrix.os == 'macos-latest' + run: go test -race -v ./pkg/spooledtempfile/... + - name: Benchmarks + if: matrix.os == 'ubuntu-latest' run: go test -bench=. -benchmem -run=^$ ./... - + + # Platform-specific test verification + - name: Test Linux-specific memory implementation + if: matrix.os == 'ubuntu-latest' + run: | + echo "Running Linux-specific memory tests..." + cd pkg/spooledtempfile + go test -v -run "TestCgroup|TestHostMeminfo|TestRead" + + - name: Test macOS-specific memory implementation + if: matrix.os == 'macos-latest' + run: | + echo "Running macOS-specific memory tests..." + cd pkg/spooledtempfile + go test -v -run "TestGetSystemMemoryUsedFraction|TestSysctlMemoryValues|TestMemoryFractionConsistency" + + # Cross-compilation verification + - name: Cross-compile for macOS (from Linux) + if: matrix.os == 'ubuntu-latest' + run: GOOS=darwin GOARCH=amd64 go build ./... + + - name: Cross-compile for Linux (from macOS) + if: matrix.os == 'macos-latest' + run: GOOS=linux GOARCH=amd64 go build ./... diff --git a/client_test.go b/client_test.go index 669c989..38bda12 100644 --- a/client_test.go +++ b/client_test.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "errors" + "fmt" "io" "math/big" "net" @@ -65,6 +66,46 @@ func defaultBenchmarkRotatorSettings(t *testing.B) *RotatorSettings { return rotatorSettings } +// sumRecordContentLengths returns the total Content-Length across all records in a WARC file. 
+func sumRecordContentLengths(path string) (int64, error) { + file, err := os.Open(path) + if err != nil { + return 0, err + } + defer file.Close() + + reader, err := NewReader(file) + if err != nil { + return 0, err + } + + var total int64 + for { + record, err := reader.ReadRecord() + if err != nil { + if err == io.EOF { + break + } + return 0, err + } + + clStr := record.Header.Get("Content-Length") + cl, err := strconv.ParseInt(clStr, 10, 64) + if err != nil { + record.Content.Close() + return 0, fmt.Errorf("parsing Content-Length %q: %w", clStr, err) + } + + total += cl + + if err := record.Content.Close(); err != nil { + return 0, err + } + } + + return total, nil +} + // Helper function used in many tests func drainErrChan(t *testing.T, errChan chan *Error) func() { var wg sync.WaitGroup @@ -155,21 +196,27 @@ func TestHTTPClient(t *testing.T) { t.Fatal(err) } + var expectedPayloadBytes int64 for _, path := range files { testFileSingleHashCheck(t, path, "sha1:UIRWL5DFIPQ4MX3D3GFHM2HCVU3TZ6I3", []string{"26872"}, 1, server.URL+"/testdata/image.svg") + + totalBytes, err := sumRecordContentLengths(path) + if err != nil { + t.Fatalf("failed to sum record content lengths for %s: %v", path, err) + } + expectedPayloadBytes += totalBytes } // verify that the remote dedupe count is correct dataTotal := httpClient.DataTotal.Load() - if dataTotal < 27130 || dataTotal > 27160 { - t.Fatalf("total bytes downloaded mismatch, expected: 27130-27160 got: %d", dataTotal) + if dataTotal != expectedPayloadBytes { + t.Fatalf("total bytes downloaded mismatch, expected %d got %d", expectedPayloadBytes, dataTotal) } } func TestHTTPClientRequestFailing(t *testing.T) { var ( rotatorSettings = defaultRotatorSettings(t) - errWg sync.WaitGroup err error ) @@ -182,11 +229,14 @@ func TestHTTPClientRequestFailing(t *testing.T) { if err != nil { t.Fatalf("Unable to init WARC writing HTTP client: %s", err) } - errWg.Add(1) + + errCh := make(chan *Error, 1) + var errChWg sync.WaitGroup + errChWg.Add(1) go func() { - defer errWg.Done() - for _ = range httpClient.ErrChan { - // We expect an error here, so we don't need to log it + defer errChWg.Done() + for err := range httpClient.ErrChan { + errCh <- err } }() @@ -201,10 +251,21 @@ func TestHTTPClientRequestFailing(t *testing.T) { _, err = httpClient.Do(req) if err == nil { - t.Fatal("expected error on Do, got none") + select { + case recv := <-errCh: + if recv == nil { + t.Fatal("expected error via ErrChan but channel closed without value") + } + case <-time.After(2 * time.Second): + t.Fatal("expected error on Do or via ErrChan, got none") + } + } else { + t.Logf("got expected error: %v", err) } httpClient.Close() + errChWg.Wait() + close(errCh) } func TestHTTPClientConnReadDeadline(t *testing.T) { @@ -596,15 +657,15 @@ func TestHTTPClientWithProxy(t *testing.T) { // init socks5 proxy server proxyServer := socks5.NewServer() + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen for proxy: %v", err) + } // Create a channel to signal server stop stopChan := make(chan struct{}) go func() { - listener, err := net.Listen("tcp", "127.0.0.1:8000") - if err != nil { - panic(err) - } defer listener.Close() go func() { @@ -617,6 +678,7 @@ func TestHTTPClientWithProxy(t *testing.T) { } }() + proxyAddr := listener.Addr().String() // Defer sending the stop signal defer close(stopChan) @@ -627,7 +689,7 @@ func TestHTTPClientWithProxy(t *testing.T) { // init the HTTP client responsible for recording HTTP(s) requests / responses httpClient, 
err := NewWARCWritingHTTPClient(HTTPClientSettings{ RotatorSettings: rotatorSettings, - Proxy: "socks5://127.0.0.1:8000"}) + Proxy: fmt.Sprintf("socks5://%s", proxyAddr)}) if err != nil { t.Fatalf("Unable to init WARC writing HTTP client: %s", err) } diff --git a/cmd/warc/mend/mend_test.go b/cmd/warc/mend/mend_test.go index 5750f85..f72783d 100644 --- a/cmd/warc/mend/mend_test.go +++ b/cmd/warc/mend/mend_test.go @@ -6,15 +6,25 @@ import ( "io" "os" "path/filepath" + "runtime" "testing" "github.com/internetarchive/gowarc/cmd/warc/verify" "github.com/spf13/cobra" ) +// getTestdataDir returns the path to the testdata directory, resolved relative to this test file. +// This ensures tests work regardless of the working directory (e.g., from root, CI/CD, etc.). +// Test file is at: cmd/warc/mend/mend_test.go, testdata is at: testdata/warcs +// So we need to go up 3 levels from the test file. +func getTestdataDir() string { + _, filename, _, _ := runtime.Caller(1) + return filepath.Join(filepath.Dir(filename), "../../../testdata/warcs") +} + // TestAnalyzeWARCFile tests the analysis of different WARC files func TestAnalyzeWARCFile(t *testing.T) { - testdataDir := "../../testdata/warcs" + testdataDir := getTestdataDir() tests := []struct { name string @@ -128,7 +138,7 @@ func TestAnalyzeWARCFile(t *testing.T) { // TestMendResultValidation tests that mendResult structs are properly populated func TestMendResultValidation(t *testing.T) { - testdataDir := "../../testdata/warcs" + testdataDir := getTestdataDir() // Test a file that should have all fields populated filePath := filepath.Join(testdataDir, "corrupted-trailing-bytes.warc.gz.open") @@ -183,7 +193,7 @@ func TestMendResultValidation(t *testing.T) { // TestAnalyzeWARCFileForceMode tests analyzeWARCFile with force=true on good closed WARC files func TestAnalyzeWARCFileForceMode(t *testing.T) { - testdataDir := "../../testdata/warcs" + testdataDir := getTestdataDir() tests := []struct { name string @@ -255,7 +265,7 @@ func TestAnalyzeWARCFileForceMode(t *testing.T) { // TestSkipNonOpenFiles tests that non-.open files are correctly skipped func TestSkipNonOpenFiles(t *testing.T) { - testdataDir := "../../testdata/warcs" + testdataDir := getTestdataDir() filePath := filepath.Join(testdataDir, "skip-non-open.warc.gz") // Check if test file exists @@ -305,7 +315,7 @@ var mendExpectedResults = map[string]expectedResult{ recordCount: 1, // Actual count from mend operation truncateAt: 0, // No truncation needed description: "good synthetic file with .open suffix", - shouldBeValid: false, // File has WARC header corruption that mend can't fix + shouldBeValid: true, // After removing the .open suffix the WARC remains valid }, "empty.warc.gz.open": { outputFile: "empty.warc.gz", @@ -321,7 +331,7 @@ var mendExpectedResults = map[string]expectedResult{ recordCount: 1, // Actual count from mend operation truncateAt: 2362, // Truncates trailing garbage description: "synthetic file with trailing garbage bytes", - shouldBeValid: false, // File has WARC header corruption that mend can't fix + shouldBeValid: true, // Truncating the trailing garbage yields a valid WARC record }, "corrupted-mid-record.warc.gz.open": { outputFile: "corrupted-mid-record.warc.gz", @@ -329,7 +339,7 @@ var mendExpectedResults = map[string]expectedResult{ recordCount: 1, // Actual count from mend operation truncateAt: 1219, description: "synthetic file corrupted mid-record", - shouldBeValid: false, // File has WARC header corruption that mend can't fix + shouldBeValid: true, // 
Truncating back to the last valid position restores a valid record }, } @@ -359,14 +369,7 @@ func createMockCobraCommand() *cobra.Command { // TestMendFunctionDirect verifies that the mend function produces // expected results on synthetic test data by comparing against pre-computed checksums func TestMendFunctionDirect(t *testing.T) { - // Get current directory and construct paths relative to workspace root - cwd, err := os.Getwd() - if err != nil { - t.Fatalf("failed to get current directory: %v", err) - } - // From cmd/mend, go up to workspace root - workspaceRoot := filepath.Join(cwd, "../..") - testdataDir := filepath.Join(workspaceRoot, "testdata/warcs") + testdataDir := getTestdataDir() outputDir := filepath.Join(testdataDir, "mend_test_output") // Ensure output directory exists @@ -505,7 +508,7 @@ func copyFile(src, dst string) error { // TestIsGzipFile tests the gzip file detection function func TestIsGzipFile(t *testing.T) { - testdataDir := "../../testdata/warcs" + testdataDir := getTestdataDir() tests := []struct { name string @@ -643,7 +646,7 @@ func TestConfirmAction(t *testing.T) { // TestMendDryRun tests the mend function in dry-run mode func TestMendDryRun(t *testing.T) { - testdataDir := "../../testdata/warcs" + testdataDir := getTestdataDir() tempDir, err := os.MkdirTemp("", "mend_dry_run_test_*") if err != nil { t.Fatalf("failed to create temp dir: %v", err) diff --git a/pkg/spooledtempfile/memory.go b/pkg/spooledtempfile/memory.go index 3309bc5..9269893 100644 --- a/pkg/spooledtempfile/memory.go +++ b/pkg/spooledtempfile/memory.go @@ -1,195 +1,53 @@ package spooledtempfile import ( - "bufio" - "fmt" - "os" - "strconv" - "strings" + "sync" + "time" ) -// Overridable in tests: -var ( - cgv2UsagePath = "/sys/fs/cgroup/memory.current" - cgv2HighPath = "/sys/fs/cgroup/memory.high" - cgv2MaxPath = "/sys/fs/cgroup/memory.max" - - cgv1UsagePath = "/sys/fs/cgroup/memory/memory.usage_in_bytes" - cgv1LimitPath = "/sys/fs/cgroup/memory/memory.limit_in_bytes" - - procMeminfoPath = "/proc/meminfo" +const ( + // memoryCheckInterval defines how often we check system memory usage. + memoryCheckInterval = 500 * time.Millisecond ) -// getSystemMemoryUsedFraction returns used/limit for the container if -// cgroup limits are set; otherwise falls back to host /proc/meminfo. 
-var getSystemMemoryUsedFraction = func() (float64, error) { - probes := []func() (float64, bool, error){ - cgroupV2UsedFraction, - cgroupV1UsedFraction, - } - - for _, p := range probes { - if frac, ok, err := p(); err != nil { - return 0, err - } else if ok { - return frac, nil - } - } - - return hostMeminfoUsedFraction() -} - -func cgroupV2UsedFraction() (frac float64, ok bool, err error) { - usage, uok, err := readUint64FileIfExists(cgv2UsagePath) - if err != nil { - return 0, false, err - } - if !uok { - return 0, false, nil // not cgroup v2 (or not accessible) - } - - // Try memory.high first - highStr, hok, err := readStringFileIfExists(cgv2HighPath) - if err != nil { - return 0, false, err - } - - var high, max uint64 - var haveHigh bool - if hok { - hs := strings.TrimSpace(highStr) - if hs != "" && hs != "max" { - if v, e := strconv.ParseUint(hs, 10, 64); e == nil && v > 0 { - high, haveHigh = v, true - } - } - } - - // Always read memory.max as fallback (and for sanity checks) - maxStr, mok, err := readStringFileIfExists(cgv2MaxPath) - if err != nil { - return 0, false, err - } - var haveMax bool - if mok { - ms := strings.TrimSpace(maxStr) - if ms != "" && ms != "max" { - if v, e := strconv.ParseUint(ms, 10, 64); e == nil && v > 0 { - max, haveMax = v, true - } - } - } - - // Choose denominator: prefer valid 'high' unless it is >= max. - switch { - case haveHigh && haveMax && high < max: - return float64(usage) / float64(high), true, nil - case haveMax: - return float64(usage) / float64(max), true, nil - case haveHigh: - return float64(usage) / float64(high), true, nil - default: - return 0, false, nil // no effective limit - } -} - -func cgroupV1UsedFraction() (frac float64, ok bool, err error) { - usage, uok, err := readUint64FileIfExists(cgv1UsagePath) - if err != nil { - return 0, false, err - } - - limit, lok, err := readUint64FileIfExists(cgv1LimitPath) - if err != nil { - return 0, false, err - } - if !uok || !lok || limit == 0 { - return 0, false, nil - } - - // Some kernels report a huge limit (e.g., ~max uint64) to mean "no limit" - if limit > (1 << 60) { // heuristic ~ 1 exabyte - return 0, false, nil - } - - return float64(usage) / float64(limit), true, nil +type globalMemoryCache struct { + sync.Mutex + lastChecked time.Time + lastFraction float64 } -func hostMeminfoUsedFraction() (float64, error) { - f, err := os.Open(procMeminfoPath) - if err != nil { - return 0, fmt.Errorf("failed to open /proc/meminfo: %v", err) - } - defer f.Close() +var ( + memoryUsageCache = &globalMemoryCache{} +) - var memTotal, memAvailable, memFree, buffers, cached uint64 - sc := bufio.NewScanner(f) - for sc.Scan() { - line := sc.Text() - fields := strings.Fields(line) - if len(fields) < 2 { - continue - } - key := strings.TrimRight(fields[0], ":") - val, _ := strconv.ParseUint(fields[1], 10, 64) // kB - switch key { - case "MemTotal": - memTotal = val - case "MemAvailable": - memAvailable = val - case "MemFree": - memFree = val - case "Buffers": - buffers = val - case "Cached": - cached = val - } - } - if err := sc.Err(); err != nil { - return 0, fmt.Errorf("scanner error reading /proc/meminfo: %v", err) - } - if memTotal == 0 { - return 0, fmt.Errorf("could not find MemTotal in /proc/meminfo") - } +// getCachedMemoryUsage returns the cached memory usage fraction, or fetches a new one +// if the cache has expired. This reduces the overhead of checking memory usage on every +// write operation. 
+func getCachedMemoryUsage() (float64, error) { + memoryUsageCache.Lock() + defer memoryUsageCache.Unlock() - var used uint64 - if memAvailable > 0 { - used = memTotal - memAvailable - } else { - approxAvailable := memFree + buffers + cached - used = memTotal - approxAvailable + if time.Since(memoryUsageCache.lastChecked) < memoryCheckInterval { + return memoryUsageCache.lastFraction, nil } - // meminfo is in kB; unit cancels in the fraction - return float64(used) / float64(memTotal), nil -} - -func readUint64FileIfExists(path string) (val uint64, ok bool, err error) { - data, err := os.ReadFile(path) + fraction, err := getSystemMemoryUsedFraction() if err != nil { - if os.IsNotExist(err) { - return 0, false, nil - } - return 0, false, err + return 0, err } - // v2 may use "max"; caller handles that as not-ok - v, perr := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64) - if perr != nil { - return 0, false, nil - } + memoryUsageCache.lastChecked = time.Now() + memoryUsageCache.lastFraction = fraction - return v, true, nil + return fraction, nil } -func readStringFileIfExists(path string) (string, bool, error) { - data, err := os.ReadFile(path) - if err != nil { - if os.IsNotExist(err) { - return "", false, nil - } - return "", false, err - } +// ResetMemoryCache clears the cached memory usage state. This is primarily used in tests +// to prevent state pollution between test packages. +func ResetMemoryCache() { + memoryUsageCache.Lock() + defer memoryUsageCache.Unlock() - return string(data), true, nil + memoryUsageCache.lastChecked = time.Time{} + memoryUsageCache.lastFraction = 0 } diff --git a/pkg/spooledtempfile/memory_darwin.go b/pkg/spooledtempfile/memory_darwin.go new file mode 100644 index 0000000..22b4848 --- /dev/null +++ b/pkg/spooledtempfile/memory_darwin.go @@ -0,0 +1,77 @@ +//go:build darwin + +package spooledtempfile + +import ( + "fmt" + + "golang.org/x/sys/unix" +) + +// getSystemMemoryUsedFraction returns the fraction of system memory currently in use on macOS. +// It uses sysctl to query system memory statistics. 
+var getSystemMemoryUsedFraction = func() (float64, error) { + // Get total physical memory using sysctl + totalBytes, err := unix.SysctlUint64("hw.memsize") + if err != nil { + return 0, fmt.Errorf("failed to get hw.memsize: %w", err) + } + + if totalBytes == 0 { + return 0, fmt.Errorf("hw.memsize returned 0") + } + + // Get page size + pageSize, err := unix.SysctlUint32("vm.pagesize") + if err != nil { + return 0, fmt.Errorf("failed to get vm.pagesize: %w", err) + } + + // Get page counts using the sysctl values that actually exist on macOS + // Note: macOS doesn't expose vm.page_active_count or vm.page_wire_count via sysctl + // We use the available values: + // - vm.page_free_count: free pages + // - vm.page_purgeable_count: purgeable pages (can be reclaimed) + // - vm.page_speculative_count: speculative pages + + freePages, err := unix.SysctlUint32("vm.page_free_count") + if err != nil { + return 0, fmt.Errorf("failed to get vm.page_free_count: %w", err) + } + + purgeablePages, err := unix.SysctlUint32("vm.page_purgeable_count") + if err != nil { + return 0, fmt.Errorf("failed to get vm.page_purgeable_count: %w", err) + } + + speculativePages, err := unix.SysctlUint32("vm.page_speculative_count") + if err != nil { + return 0, fmt.Errorf("failed to get vm.page_speculative_count: %w", err) + } + + // Calculate used memory + // Used = Total - (Free + Purgeable + Speculative) + totalPages := totalBytes / uint64(pageSize) + reclaimablePages := uint64(freePages) + uint64(purgeablePages) + uint64(speculativePages) + + // Clamp to prevent underflow: if reclaimable > total, use total + var usedPages uint64 + if reclaimablePages < totalPages { + usedPages = totalPages - reclaimablePages + } else { + usedPages = 0 + } + + usedBytes := usedPages * uint64(pageSize) + + // Calculate fraction + fraction := float64(usedBytes) / float64(totalBytes) + + // Sanity check: fraction should be between 0 and 1 + if fraction < 0 || fraction > 1 { + return 0, fmt.Errorf("calculated memory fraction out of range: %v (used: %d, total: %d)", + fraction, usedBytes, totalBytes) + } + + return fraction, nil +} diff --git a/pkg/spooledtempfile/memory_darwin_test.go b/pkg/spooledtempfile/memory_darwin_test.go new file mode 100644 index 0000000..2c585ee --- /dev/null +++ b/pkg/spooledtempfile/memory_darwin_test.go @@ -0,0 +1,91 @@ +//go:build darwin + +package spooledtempfile + +import ( + "testing" + + "golang.org/x/sys/unix" +) + +// TestGetSystemMemoryUsedFraction verifies that the macOS implementation +// returns a valid memory fraction between 0 and 1. +func TestGetSystemMemoryUsedFraction(t *testing.T) { + fraction, err := getSystemMemoryUsedFraction() + if err != nil { + t.Fatalf("getSystemMemoryUsedFraction() failed: %v", err) + } + + if fraction < 0 || fraction > 1 { + t.Fatalf("memory fraction out of range: got %v, want 0.0-1.0", fraction) + } + + // Log the result for informational purposes + t.Logf("Current system memory usage: %.2f%%", fraction*100) +} + +// TestSysctlMemoryValues verifies that we can successfully retrieve memory values via sysctl. 
+func TestSysctlMemoryValues(t *testing.T) { + // Test hw.memsize + totalBytes, err := unix.SysctlUint64("hw.memsize") + if err != nil { + t.Fatalf("failed to get hw.memsize: %v", err) + } + if totalBytes == 0 { + t.Fatal("hw.memsize returned 0") + } + t.Logf("Total memory: %d bytes (%.2f GB)", totalBytes, float64(totalBytes)/(1024*1024*1024)) + + // Test vm.pagesize + pageSize, err := unix.SysctlUint32("vm.pagesize") + if err != nil { + t.Fatalf("failed to get vm.pagesize: %v", err) + } + if pageSize == 0 { + t.Fatal("vm.pagesize returned 0") + } + t.Logf("Page size: %d bytes", pageSize) + + // Test page counts + freePages, err := unix.SysctlUint32("vm.page_free_count") + if err != nil { + t.Fatalf("failed to get vm.page_free_count: %v", err) + } + t.Logf("Free pages: %d (%.2f MB)", freePages, float64(freePages*pageSize)/(1024*1024)) + + pageableInternal, err := unix.SysctlUint32("vm.page_pageable_internal_count") + if err != nil { + t.Fatalf("failed to get vm.page_pageable_internal_count: %v", err) + } + t.Logf("Pageable internal pages: %d (%.2f MB)", pageableInternal, float64(pageableInternal*pageSize)/(1024*1024)) + + pageableExternal, err := unix.SysctlUint32("vm.page_pageable_external_count") + if err != nil { + t.Fatalf("failed to get vm.page_pageable_external_count: %v", err) + } + t.Logf("Pageable external pages: %d (%.2f MB)", pageableExternal, float64(pageableExternal*pageSize)/(1024*1024)) +} + +// TestMemoryFractionConsistency verifies that multiple calls return consistent values. +func TestMemoryFractionConsistency(t *testing.T) { + const calls = 5 + var fractions [calls]float64 + + for i := 0; i < calls; i++ { + frac, err := getSystemMemoryUsedFraction() + if err != nil { + t.Fatalf("call %d failed: %v", i, err) + } + fractions[i] = frac + } + + // Check that all values are within a reasonable range of each other + // Memory usage shouldn't vary wildly between consecutive calls + for i := 1; i < calls; i++ { + diff := fractions[i] - fractions[i-1] + if diff < -0.2 || diff > 0.2 { + t.Errorf("memory fraction changed too much between calls: %v -> %v (diff: %v)", + fractions[i-1], fractions[i], diff) + } + } +} diff --git a/pkg/spooledtempfile/memory_linux.go b/pkg/spooledtempfile/memory_linux.go new file mode 100644 index 0000000..8ccd101 --- /dev/null +++ b/pkg/spooledtempfile/memory_linux.go @@ -0,0 +1,197 @@ +//go:build linux + +package spooledtempfile + +import ( + "bufio" + "fmt" + "os" + "strconv" + "strings" +) + +// Overridable in tests: +var ( + cgv2UsagePath = "/sys/fs/cgroup/memory.current" + cgv2HighPath = "/sys/fs/cgroup/memory.high" + cgv2MaxPath = "/sys/fs/cgroup/memory.max" + + cgv1UsagePath = "/sys/fs/cgroup/memory/memory.usage_in_bytes" + cgv1LimitPath = "/sys/fs/cgroup/memory/memory.limit_in_bytes" + + procMeminfoPath = "/proc/meminfo" +) + +// getSystemMemoryUsedFraction returns used/limit for the container if +// cgroup limits are set; otherwise falls back to host /proc/meminfo. 
+var getSystemMemoryUsedFraction = func() (float64, error) { + probes := []func() (float64, bool, error){ + cgroupV2UsedFraction, + cgroupV1UsedFraction, + } + + for _, p := range probes { + if frac, ok, err := p(); err != nil { + return 0, err + } else if ok { + return frac, nil + } + } + + return hostMeminfoUsedFraction() +} + +func cgroupV2UsedFraction() (frac float64, ok bool, err error) { + usage, uok, err := readUint64FileIfExists(cgv2UsagePath) + if err != nil { + return 0, false, err + } + if !uok { + return 0, false, nil // not cgroup v2 (or not accessible) + } + + // Try memory.high first + highStr, hok, err := readStringFileIfExists(cgv2HighPath) + if err != nil { + return 0, false, err + } + + var high, max uint64 + var haveHigh bool + if hok { + hs := strings.TrimSpace(highStr) + if hs != "" && hs != "max" { + if v, e := strconv.ParseUint(hs, 10, 64); e == nil && v > 0 { + high, haveHigh = v, true + } + } + } + + // Always read memory.max as fallback (and for sanity checks) + maxStr, mok, err := readStringFileIfExists(cgv2MaxPath) + if err != nil { + return 0, false, err + } + var haveMax bool + if mok { + ms := strings.TrimSpace(maxStr) + if ms != "" && ms != "max" { + if v, e := strconv.ParseUint(ms, 10, 64); e == nil && v > 0 { + max, haveMax = v, true + } + } + } + + // Choose denominator: prefer valid 'high' unless it is >= max. + switch { + case haveHigh && haveMax && high < max: + return float64(usage) / float64(high), true, nil + case haveMax: + return float64(usage) / float64(max), true, nil + case haveHigh: + return float64(usage) / float64(high), true, nil + default: + return 0, false, nil // no effective limit + } +} + +func cgroupV1UsedFraction() (frac float64, ok bool, err error) { + usage, uok, err := readUint64FileIfExists(cgv1UsagePath) + if err != nil { + return 0, false, err + } + + limit, lok, err := readUint64FileIfExists(cgv1LimitPath) + if err != nil { + return 0, false, err + } + if !uok || !lok || limit == 0 { + return 0, false, nil + } + + // Some kernels report a huge limit (e.g., ~max uint64) to mean "no limit" + if limit > (1 << 60) { // heuristic ~ 1 exabyte + return 0, false, nil + } + + return float64(usage) / float64(limit), true, nil +} + +func hostMeminfoUsedFraction() (float64, error) { + f, err := os.Open(procMeminfoPath) + if err != nil { + return 0, fmt.Errorf("failed to open %s: %v", procMeminfoPath, err) + } + defer f.Close() + + var memTotal, memAvailable, memFree, buffers, cached uint64 + sc := bufio.NewScanner(f) + for sc.Scan() { + line := sc.Text() + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + key := strings.TrimRight(fields[0], ":") + val, _ := strconv.ParseUint(fields[1], 10, 64) // kB + switch key { + case "MemTotal": + memTotal = val + case "MemAvailable": + memAvailable = val + case "MemFree": + memFree = val + case "Buffers": + buffers = val + case "Cached": + cached = val + } + } + if err := sc.Err(); err != nil { + return 0, fmt.Errorf("scanner error reading /proc/meminfo: %v", err) + } + if memTotal == 0 { + return 0, fmt.Errorf("could not find MemTotal in /proc/meminfo") + } + + var used uint64 + if memAvailable > 0 { + used = memTotal - memAvailable + } else { + approxAvailable := memFree + buffers + cached + used = memTotal - approxAvailable + } + + // meminfo is in kB; unit cancels in the fraction + return float64(used) / float64(memTotal), nil +} + +func readUint64FileIfExists(path string) (val uint64, ok bool, err error) { + data, err := os.ReadFile(path) + if err != nil { + if 
os.IsNotExist(err) { + return 0, false, nil + } + return 0, false, err + } + + // v2 may use "max"; caller handles that as not-ok + v, perr := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64) + if perr != nil { + return 0, false, nil + } + + return v, true, nil +} + +func readStringFileIfExists(path string) (string, bool, error) { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return "", false, nil + } + return "", false, err + } + + return string(data), true, nil +} diff --git a/pkg/spooledtempfile/memory_linux_test.go b/pkg/spooledtempfile/memory_linux_test.go new file mode 100644 index 0000000..c0e3efb --- /dev/null +++ b/pkg/spooledtempfile/memory_linux_test.go @@ -0,0 +1,364 @@ +//go:build linux + +package spooledtempfile + +import ( + "os" + "path/filepath" + "strconv" + "strings" + "testing" +) + +// --- helpers --- + +type savedPaths struct { + v2Usage, v2High, v2Max string + v1Usage, v1Limit string + meminfo string +} + +func writeFile(t *testing.T, dir, name, content string) string { + t.Helper() + p := filepath.Join(dir, name) + if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(p, []byte(content), 0o644); err != nil { + t.Fatalf("write %s: %v", p, err) + } + return p +} + +func saveAndRedirectPaths(t *testing.T, base string) (restore func()) { + t.Helper() + + old := savedPaths{ + v2Usage: cgv2UsagePath, v2High: cgv2HighPath, v2Max: cgv2MaxPath, + v1Usage: cgv1UsagePath, v1Limit: cgv1LimitPath, + meminfo: procMeminfoPath, + } + + // Point to files under base; tests will create only what they need. + cgv2UsagePath = filepath.Join(base, "sys/fs/cgroup/memory.current") + cgv2HighPath = filepath.Join(base, "sys/fs/cgroup/memory.high") + cgv2MaxPath = filepath.Join(base, "sys/fs/cgroup/memory.max") + cgv1UsagePath = filepath.Join(base, "sys/fs/cgroup/memory/memory.usage_in_bytes") + cgv1LimitPath = filepath.Join(base, "sys/fs/cgroup/memory/memory.limit_in_bytes") + procMeminfoPath = filepath.Join(base, "proc/meminfo") + + return func() { + cgv2UsagePath, cgv2HighPath, cgv2MaxPath = old.v2Usage, old.v2High, old.v2Max + cgv1UsagePath, cgv1LimitPath = old.v1Usage, old.v1Limit + procMeminfoPath = old.meminfo + } +} + +// --- cgroup v2 --- + +func TestCgroupV2_UsesHighWhenStricterThanMax(t *testing.T) { + dir := t.TempDir() + restore := saveAndRedirectPaths(t, dir) + defer restore() + + // usage=400 + writeFile(t, dir, "sys/fs/cgroup/memory.current", "400") + // high=800, max=1000 -> use high, frac = 0.5 + writeFile(t, dir, "sys/fs/cgroup/memory.high", "800") + writeFile(t, dir, "sys/fs/cgroup/memory.max", "1000") + + frac, ok, err := cgroupV2UsedFraction() + if err != nil { + t.Fatalf("err: %v", err) + } + if !ok { + t.Fatalf("expected ok") + } + if got, want := frac, 0.5; got != want { + t.Fatalf("frac=%v want=%v", got, want) + } +} + +func TestCgroupV2_FallbackToMaxWhenHighIsMax(t *testing.T) { + dir := t.TempDir() + restore := saveAndRedirectPaths(t, dir) + defer restore() + + writeFile(t, dir, "sys/fs/cgroup/memory.current", "300") + writeFile(t, dir, "sys/fs/cgroup/memory.high", "max") // unset + writeFile(t, dir, "sys/fs/cgroup/memory.max", "600") // use this + + frac, ok, err := cgroupV2UsedFraction() + if err != nil { + t.Fatalf("err: %v", err) + } + if !ok { + t.Fatalf("expected ok") + } + if got, want := frac, 0.5; got != want { + t.Fatalf("frac=%v want=%v", got, want) + } +} + +func TestCgroupV2_FallbackToMaxWhenHighInvalid(t *testing.T) { + dir := t.TempDir() + 
restore := saveAndRedirectPaths(t, dir) + defer restore() + + writeFile(t, dir, "sys/fs/cgroup/memory.current", "256") + writeFile(t, dir, "sys/fs/cgroup/memory.high", "not-a-number") + writeFile(t, dir, "sys/fs/cgroup/memory.max", "512") + + frac, ok, err := cgroupV2UsedFraction() + if err != nil { + t.Fatalf("err: %v", err) + } + if !ok { + t.Fatalf("expected ok") + } + if got, want := frac, 0.5; got != want { + t.Fatalf("frac=%v want=%v", got, want) + } +} + +func TestCgroupV2_UseMaxWhenHighGTE_Max(t *testing.T) { + dir := t.TempDir() + restore := saveAndRedirectPaths(t, dir) + defer restore() + + writeFile(t, dir, "sys/fs/cgroup/memory.current", "900") + writeFile(t, dir, "sys/fs/cgroup/memory.high", "1000") // >= max + writeFile(t, dir, "sys/fs/cgroup/memory.max", "1000") + + frac, ok, err := cgroupV2UsedFraction() + if err != nil { + t.Fatalf("err: %v", err) + } + if !ok { + t.Fatalf("expected ok") + } + if got, want := frac, 0.9; got != want { + t.Fatalf("frac=%v want=%v", got, want) + } +} + +func TestCgroupV2_OnlyHighSet(t *testing.T) { + dir := t.TempDir() + restore := saveAndRedirectPaths(t, dir) + defer restore() + + writeFile(t, dir, "sys/fs/cgroup/memory.current", "50") + writeFile(t, dir, "sys/fs/cgroup/memory.high", "100") + // no max file + + frac, ok, err := cgroupV2UsedFraction() + if err != nil { + t.Fatalf("err: %v", err) + } + if !ok { + t.Fatalf("expected ok") + } + if got, want := frac, 0.5; got != want { + t.Fatalf("frac=%v want=%v", got, want) + } +} + +func TestCgroupV2_NoLimitsOrUsageFile(t *testing.T) { + dir := t.TempDir() + restore := saveAndRedirectPaths(t, dir) + defer restore() + + // usage missing -> ok=false (not cgroup v2 / not accessible) + _, ok, err := cgroupV2UsedFraction() + if err != nil { + t.Fatalf("err: %v", err) + } + if ok { + t.Fatalf("expected ok=false when usage missing") + } + + // Now create usage but high=max and no max => no effective limit + writeFile(t, dir, "sys/fs/cgroup/memory.current", "123") + writeFile(t, dir, "sys/fs/cgroup/memory.high", "max") + frac, ok, err := cgroupV2UsedFraction() + if err != nil { + t.Fatalf("err: %v", err) + } + if ok { + t.Fatalf("expected ok=false with no effective limit, got ok and frac=%v", frac) + } +} + +// --- cgroup v1 --- + +func TestCgroupV1_NormalFraction(t *testing.T) { + dir := t.TempDir() + restore := saveAndRedirectPaths(t, dir) + defer restore() + + writeFile(t, dir, "sys/fs/cgroup/memory/memory.usage_in_bytes", "200") + writeFile(t, dir, "sys/fs/cgroup/memory/memory.limit_in_bytes", "400") + + frac, ok, err := cgroupV1UsedFraction() + if err != nil { + t.Fatalf("err: %v", err) + } + if !ok { + t.Fatalf("expected ok") + } + if got, want := frac, 0.5; got != want { + t.Fatalf("frac=%v want=%v", got, want) + } +} + +func TestCgroupV1_HugeLimitMeansNoLimit(t *testing.T) { + dir := t.TempDir() + restore := saveAndRedirectPaths(t, dir) + defer restore() + + writeFile(t, dir, "sys/fs/cgroup/memory/memory.usage_in_bytes", "42") + // > 1<<60 + writeFile(t, dir, "sys/fs/cgroup/memory/memory.limit_in_bytes", strconv.FormatUint((1<<60)+1, 10)) + + _, ok, err := cgroupV1UsedFraction() + if err != nil { + t.Fatalf("err: %v", err) + } + if ok { + t.Fatalf("expected ok=false for huge limit (no limit)") + } +} + +func TestCgroupV1_MissingFilesOrZeroLimit(t *testing.T) { + dir := t.TempDir() + restore := saveAndRedirectPaths(t, dir) + defer restore() + + // Only usage present -> not ok + writeFile(t, dir, "sys/fs/cgroup/memory/memory.usage_in_bytes", "7") + _, ok, err := cgroupV1UsedFraction() + if err 
!= nil { + t.Fatalf("err: %v", err) + } + if ok { + t.Fatalf("expected ok=false when limit missing") + } + + // Now add zero limit -> not ok + writeFile(t, dir, "sys/fs/cgroup/memory/memory.limit_in_bytes", "0") + _, ok, err = cgroupV1UsedFraction() + if err != nil { + t.Fatalf("err: %v", err) + } + if ok { + t.Fatalf("expected ok=false for zero limit") + } +} + +// --- /proc/meminfo fallback --- + +func TestHostMeminfo_UsesMemAvailableWhenPresent(t *testing.T) { + dir := t.TempDir() + restore := saveAndRedirectPaths(t, dir) + defer restore() + + // MemTotal and MemAvailable are in kB + meminfo := strings.Join([]string{ + "MemTotal: 1000000 kB", + "MemAvailable: 250000 kB", + "Buffers: 10000 kB", + "Cached: 20000 kB", + "MemFree: 50000 kB", + }, "\n") + writeFile(t, dir, "proc/meminfo", meminfo) + + frac, err := hostMeminfoUsedFraction() + if err != nil { + t.Fatalf("err: %v", err) + } + + // used = total - available = 1000000 - 250000 = 750000 => 0.75 + if got, want := frac, 0.75; got != want { + t.Fatalf("frac=%v want=%v", got, want) + } +} + +func TestHostMeminfo_FallbackWithoutMemAvailable(t *testing.T) { + dir := t.TempDir() + restore := saveAndRedirectPaths(t, dir) + defer restore() + + meminfo := strings.Join([]string{ + "MemTotal: 1000000 kB", + "MemFree: 100000 kB", + "Buffers: 50000 kB", + "Cached: 150000 kB", + }, "\n") + writeFile(t, dir, "proc/meminfo", meminfo) + + frac, err := hostMeminfoUsedFraction() + if err != nil { + t.Fatalf("err: %v", err) + } + + // approxAvailable = 100000 + 50000 + 150000 = 300000 + // used = 1000000 - 300000 = 700000 => 0.7 + if got, want := frac, 0.7; got != want { + t.Fatalf("frac=%v want=%v", got, want) + } +} + +func TestHostMeminfo_Errors(t *testing.T) { + dir := t.TempDir() + restore := saveAndRedirectPaths(t, dir) + defer restore() + + // Missing file + if _, err := hostMeminfoUsedFraction(); err == nil { + t.Fatalf("expected error when /proc/meminfo missing") + } + + // Present but missing MemTotal + writeFile(t, dir, "proc/meminfo", "MemFree: 1 kB\n") + if _, err := hostMeminfoUsedFraction(); err == nil { + t.Fatalf("expected error when MemTotal missing") + } +} + +// --- read helpers --- + +func TestReadUint64FileIfExists(t *testing.T) { + dir := t.TempDir() + + // Missing -> ok=false, err=nil + if _, ok, err := readUint64FileIfExists(filepath.Join(dir, "nope")); err != nil || ok { + t.Fatalf("expected ok=false, err=nil for missing file; got ok=%v err=%v", ok, err) + } + + // Present & valid + p := writeFile(t, dir, "n.txt", "123\n") + v, ok, err := readUint64FileIfExists(p) + if err != nil || !ok || v != 123 { + t.Fatalf("got v=%d ok=%v err=%v; want 123,true,nil", v, ok, err) + } + + // Present & invalid -> ok=false, err=nil + p = writeFile(t, dir, "bad.txt", "not-a-number") + if _, ok, err := readUint64FileIfExists(p); err != nil || ok { + t.Fatalf("expected ok=false, err=nil for invalid number; got ok=%v err=%v", ok, err) + } +} + +func TestReadStringFileIfExists(t *testing.T) { + dir := t.TempDir() + + if _, ok, err := readStringFileIfExists(filepath.Join(dir, "nope")); err != nil || ok { + t.Fatalf("expected ok=false, err=nil for missing file; got ok=%v err=%v", ok, err) + } + + p := writeFile(t, dir, "s.txt", " hello \n") + s, ok, err := readStringFileIfExists(p) + if err != nil || !ok || strings.TrimSpace(s) != "hello" { + t.Fatalf("got s=%q ok=%v err=%v; want 'hello',true,nil", s, ok, err) + } +} diff --git a/pkg/spooledtempfile/memory_test.go b/pkg/spooledtempfile/memory_test.go index 63a7c45..68a406d 100644 --- 
a/pkg/spooledtempfile/memory_test.go +++ b/pkg/spooledtempfile/memory_test.go @@ -1,362 +1,83 @@ package spooledtempfile import ( - "os" - "path/filepath" - "strconv" - "strings" "testing" + "time" ) -// --- helpers --- +// TestGetCachedMemoryUsage verifies that the caching mechanism works correctly. +func TestGetCachedMemoryUsage(t *testing.T) { + // Save original function + originalFn := getSystemMemoryUsedFraction -type savedPaths struct { - v2Usage, v2High, v2Max string - v1Usage, v1Limit string - meminfo string -} - -func writeFile(t *testing.T, dir, name, content string) string { - t.Helper() - p := filepath.Join(dir, name) - if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil { - t.Fatalf("mkdir: %v", err) - } - if err := os.WriteFile(p, []byte(content), 0o644); err != nil { - t.Fatalf("write %s: %v", p, err) - } - return p -} - -func saveAndRedirectPaths(t *testing.T, base string) (restore func()) { - t.Helper() - - old := savedPaths{ - v2Usage: cgv2UsagePath, v2High: cgv2HighPath, v2Max: cgv2MaxPath, - v1Usage: cgv1UsagePath, v1Limit: cgv1LimitPath, - meminfo: procMeminfoPath, - } - - // Point to files under base; tests will create only what they need. - cgv2UsagePath = filepath.Join(base, "sys/fs/cgroup/memory.current") - cgv2HighPath = filepath.Join(base, "sys/fs/cgroup/memory.high") - cgv2MaxPath = filepath.Join(base, "sys/fs/cgroup/memory.max") - cgv1UsagePath = filepath.Join(base, "sys/fs/cgroup/memory/memory.usage_in_bytes") - cgv1LimitPath = filepath.Join(base, "sys/fs/cgroup/memory/memory.limit_in_bytes") - procMeminfoPath = filepath.Join(base, "proc/meminfo") - - return func() { - cgv2UsagePath, cgv2HighPath, cgv2MaxPath = old.v2Usage, old.v2High, old.v2Max - cgv1UsagePath, cgv1LimitPath = old.v1Usage, old.v1Limit - procMeminfoPath = old.meminfo - } -} - -// --- cgroup v2 --- - -func TestCgroupV2_UsesHighWhenStricterThanMax(t *testing.T) { - dir := t.TempDir() - restore := saveAndRedirectPaths(t, dir) - defer restore() - - // usage=400 - writeFile(t, dir, "sys/fs/cgroup/memory.current", "400") - // high=800, max=1000 -> use high, frac = 0.5 - writeFile(t, dir, "sys/fs/cgroup/memory.high", "800") - writeFile(t, dir, "sys/fs/cgroup/memory.max", "1000") - - frac, ok, err := cgroupV2UsedFraction() - if err != nil { - t.Fatalf("err: %v", err) - } - if !ok { - t.Fatalf("expected ok") - } - if got, want := frac, 0.5; got != want { - t.Fatalf("frac=%v want=%v", got, want) - } -} - -func TestCgroupV2_FallbackToMaxWhenHighIsMax(t *testing.T) { - dir := t.TempDir() - restore := saveAndRedirectPaths(t, dir) - defer restore() - - writeFile(t, dir, "sys/fs/cgroup/memory.current", "300") - writeFile(t, dir, "sys/fs/cgroup/memory.high", "max") // unset - writeFile(t, dir, "sys/fs/cgroup/memory.max", "600") // use this - - frac, ok, err := cgroupV2UsedFraction() - if err != nil { - t.Fatalf("err: %v", err) + // Track how many times the function is called + callCount := 0 + getSystemMemoryUsedFraction = func() (float64, error) { + callCount++ + return 0.5, nil } - if !ok { - t.Fatalf("expected ok") - } - if got, want := frac, 0.5; got != want { - t.Fatalf("frac=%v want=%v", got, want) - } -} - -func TestCgroupV2_FallbackToMaxWhenHighInvalid(t *testing.T) { - dir := t.TempDir() - restore := saveAndRedirectPaths(t, dir) - defer restore() - - writeFile(t, dir, "sys/fs/cgroup/memory.current", "256") - writeFile(t, dir, "sys/fs/cgroup/memory.high", "not-a-number") - writeFile(t, dir, "sys/fs/cgroup/memory.max", "512") - - frac, ok, err := cgroupV2UsedFraction() - if err != nil { - 
t.Fatalf("err: %v", err) - } - if !ok { - t.Fatalf("expected ok") - } - if got, want := frac, 0.5; got != want { - t.Fatalf("frac=%v want=%v", got, want) - } -} - -func TestCgroupV2_UseMaxWhenHighGTE_Max(t *testing.T) { - dir := t.TempDir() - restore := saveAndRedirectPaths(t, dir) - defer restore() - - writeFile(t, dir, "sys/fs/cgroup/memory.current", "900") - writeFile(t, dir, "sys/fs/cgroup/memory.high", "1000") // >= max - writeFile(t, dir, "sys/fs/cgroup/memory.max", "1000") - frac, ok, err := cgroupV2UsedFraction() - if err != nil { - t.Fatalf("err: %v", err) - } - if !ok { - t.Fatalf("expected ok") - } - if got, want := frac, 0.9; got != want { - t.Fatalf("frac=%v want=%v", got, want) - } -} - -func TestCgroupV2_OnlyHighSet(t *testing.T) { - dir := t.TempDir() - restore := saveAndRedirectPaths(t, dir) - defer restore() - - writeFile(t, dir, "sys/fs/cgroup/memory.current", "50") - writeFile(t, dir, "sys/fs/cgroup/memory.high", "100") - // no max file - - frac, ok, err := cgroupV2UsedFraction() - if err != nil { - t.Fatalf("err: %v", err) - } - if !ok { - t.Fatalf("expected ok") - } - if got, want := frac, 0.5; got != want { - t.Fatalf("frac=%v want=%v", got, want) - } -} + // Restore at the end + defer func() { + getSystemMemoryUsedFraction = originalFn + }() -func TestCgroupV2_NoLimitsOrUsageFile(t *testing.T) { - dir := t.TempDir() - restore := saveAndRedirectPaths(t, dir) - defer restore() + // Reset cache + ResetMemoryCache() - // usage missing -> ok=false (not cgroup v2 / not accessible) - _, ok, err := cgroupV2UsedFraction() + // First call should invoke the function + frac1, err := getCachedMemoryUsage() if err != nil { - t.Fatalf("err: %v", err) - } - if ok { - t.Fatalf("expected ok=false when usage missing") + t.Fatalf("first call failed: %v", err) } - - // Now create usage but high=max and no max => no effective limit - writeFile(t, dir, "sys/fs/cgroup/memory.current", "123") - writeFile(t, dir, "sys/fs/cgroup/memory.high", "max") - frac, ok, err := cgroupV2UsedFraction() - if err != nil { - t.Fatalf("err: %v", err) + if frac1 != 0.5 { + t.Fatalf("expected 0.5, got %v", frac1) } - if ok { - t.Fatalf("expected ok=false with no effective limit, got ok and frac=%v", frac) + if callCount != 1 { + t.Fatalf("expected 1 call, got %d", callCount) } -} - -// --- cgroup v1 --- - -func TestCgroupV1_NormalFraction(t *testing.T) { - dir := t.TempDir() - restore := saveAndRedirectPaths(t, dir) - defer restore() - writeFile(t, dir, "sys/fs/cgroup/memory/memory.usage_in_bytes", "200") - writeFile(t, dir, "sys/fs/cgroup/memory/memory.limit_in_bytes", "400") - - frac, ok, err := cgroupV1UsedFraction() + // Second immediate call should use cache + frac2, err := getCachedMemoryUsage() if err != nil { - t.Fatalf("err: %v", err) + t.Fatalf("second call failed: %v", err) } - if !ok { - t.Fatalf("expected ok") + if frac2 != 0.5 { + t.Fatalf("expected 0.5, got %v", frac2) } - if got, want := frac, 0.5; got != want { - t.Fatalf("frac=%v want=%v", got, want) + if callCount != 1 { + t.Fatalf("expected still 1 call (cached), got %d", callCount) } -} - -func TestCgroupV1_HugeLimitMeansNoLimit(t *testing.T) { - dir := t.TempDir() - restore := saveAndRedirectPaths(t, dir) - defer restore() - writeFile(t, dir, "sys/fs/cgroup/memory/memory.usage_in_bytes", "42") - // > 1<<60 - writeFile(t, dir, "sys/fs/cgroup/memory/memory.limit_in_bytes", strconv.FormatUint((1<<60)+1, 10)) + // Simulate cache expiration + memoryUsageCache.Lock() + memoryUsageCache.lastChecked = time.Now().Add(-memoryCheckInterval - 
time.Millisecond) + memoryUsageCache.Unlock() - _, ok, err := cgroupV1UsedFraction() + // Next call should invoke the function again + frac3, err := getCachedMemoryUsage() if err != nil { - t.Fatalf("err: %v", err) + t.Fatalf("third call failed: %v", err) } - if ok { - t.Fatalf("expected ok=false for huge limit (no limit)") - } -} - -func TestCgroupV1_MissingFilesOrZeroLimit(t *testing.T) { - dir := t.TempDir() - restore := saveAndRedirectPaths(t, dir) - defer restore() - - // Only usage present -> not ok - writeFile(t, dir, "sys/fs/cgroup/memory/memory.usage_in_bytes", "7") - _, ok, err := cgroupV1UsedFraction() - if err != nil { - t.Fatalf("err: %v", err) - } - if ok { - t.Fatalf("expected ok=false when limit missing") - } - - // Now add zero limit -> not ok - writeFile(t, dir, "sys/fs/cgroup/memory/memory.limit_in_bytes", "0") - _, ok, err = cgroupV1UsedFraction() - if err != nil { - t.Fatalf("err: %v", err) - } - if ok { - t.Fatalf("expected ok=false for zero limit") - } -} - -// --- /proc/meminfo fallback --- - -func TestHostMeminfo_UsesMemAvailableWhenPresent(t *testing.T) { - dir := t.TempDir() - restore := saveAndRedirectPaths(t, dir) - defer restore() - - // MemTotal and MemAvailable are in kB - meminfo := strings.Join([]string{ - "MemTotal: 1000000 kB", - "MemAvailable: 250000 kB", - "Buffers: 10000 kB", - "Cached: 20000 kB", - "MemFree: 50000 kB", - }, "\n") - writeFile(t, dir, "proc/meminfo", meminfo) - - frac, err := hostMeminfoUsedFraction() - if err != nil { - t.Fatalf("err: %v", err) + if frac3 != 0.5 { + t.Fatalf("expected 0.5, got %v", frac3) } - - // used = total - available = 1000000 - 250000 = 750000 => 0.75 - if got, want := frac, 0.75; got != want { - t.Fatalf("frac=%v want=%v", got, want) + if callCount != 2 { + t.Fatalf("expected 2 calls (cache expired), got %d", callCount) } } -func TestHostMeminfo_FallbackWithoutMemAvailable(t *testing.T) { - dir := t.TempDir() - restore := saveAndRedirectPaths(t, dir) - defer restore() - - meminfo := strings.Join([]string{ - "MemTotal: 1000000 kB", - "MemFree: 100000 kB", - "Buffers: 50000 kB", - "Cached: 150000 kB", - }, "\n") - writeFile(t, dir, "proc/meminfo", meminfo) - - frac, err := hostMeminfoUsedFraction() +// TestGetSystemMemoryUsedFraction_Integration verifies that the actual implementation +// returns a valid value on the current platform. 
+func TestGetSystemMemoryUsedFraction_Integration(t *testing.T) { + fraction, err := getSystemMemoryUsedFraction() if err != nil { - t.Fatalf("err: %v", err) + t.Fatalf("getSystemMemoryUsedFraction() failed: %v", err) } - // approxAvailable = 100000 + 50000 + 150000 = 300000 - // used = 1000000 - 300000 = 700000 => 0.7 - if got, want := frac, 0.7; got != want { - t.Fatalf("frac=%v want=%v", got, want) + if fraction < 0 || fraction > 1 { + t.Fatalf("memory fraction out of range: got %v, want 0.0-1.0", fraction) } -} - -func TestHostMeminfo_Errors(t *testing.T) { - dir := t.TempDir() - restore := saveAndRedirectPaths(t, dir) - defer restore() - - // Missing file - if _, err := hostMeminfoUsedFraction(); err == nil { - t.Fatalf("expected error when /proc/meminfo missing") - } - - // Present but missing MemTotal - writeFile(t, dir, "proc/meminfo", "MemFree: 1 kB\n") - if _, err := hostMeminfoUsedFraction(); err == nil { - t.Fatalf("expected error when MemTotal missing") - } -} - -// --- read helpers --- -func TestReadUint64FileIfExists(t *testing.T) { - dir := t.TempDir() - - // Missing -> ok=false, err=nil - if _, ok, err := readUint64FileIfExists(filepath.Join(dir, "nope")); err != nil || ok { - t.Fatalf("expected ok=false, err=nil for missing file; got ok=%v err=%v", ok, err) - } - - // Present & valid - p := writeFile(t, dir, "n.txt", "123\n") - v, ok, err := readUint64FileIfExists(p) - if err != nil || !ok || v != 123 { - t.Fatalf("got v=%d ok=%v err=%v; want 123,true,nil", v, ok, err) - } - - // Present & invalid -> ok=false, err=nil - p = writeFile(t, dir, "bad.txt", "not-a-number") - if _, ok, err := readUint64FileIfExists(p); err != nil || ok { - t.Fatalf("expected ok=false, err=nil for invalid number; got ok=%v err=%v", ok, err) - } -} - -func TestReadStringFileIfExists(t *testing.T) { - dir := t.TempDir() - - if _, ok, err := readStringFileIfExists(filepath.Join(dir, "nope")); err != nil || ok { - t.Fatalf("expected ok=false, err=nil for missing file; got ok=%v err=%v", ok, err) - } - - p := writeFile(t, dir, "s.txt", " hello \n") - s, ok, err := readStringFileIfExists(p) - if err != nil || !ok || strings.TrimSpace(s) != "hello" { - t.Fatalf("got s=%q ok=%v err=%v; want 'hello',true,nil", s, ok, err) - } + t.Logf("Current system memory usage: %.2f%%", fraction*100) } diff --git a/pkg/spooledtempfile/spooled.go b/pkg/spooledtempfile/spooled.go index f351b29..bf917c6 100644 --- a/pkg/spooledtempfile/spooled.go +++ b/pkg/spooledtempfile/spooled.go @@ -7,8 +7,6 @@ import ( "io" "log" "os" - "sync" - "time" "github.com/valyala/bytebufferpool" ) @@ -20,18 +18,6 @@ const ( MaxInMemorySize = 1024 * 1024 // DefaultMaxRAMUsageFraction is the default fraction of system RAM above which we'll force spooling to disk DefaultMaxRAMUsageFraction = 0.50 - // memoryCheckInterval defines how often we check system memory usage. - memoryCheckInterval = 500 * time.Millisecond -) - -type globalMemoryCache struct { - sync.Mutex - lastChecked time.Time - lastFraction float64 -} - -var ( - memoryUsageCache = &globalMemoryCache{} ) // ReaderAt is the interface for ReadAt - read at position, without moving pointer. 
@@ -269,22 +255,3 @@ func (s *spooledTempFile) isSystemMemoryUsageHigh() bool { } return usedFraction >= s.maxRAMUsageFraction } - -func getCachedMemoryUsage() (float64, error) { - memoryUsageCache.Lock() - defer memoryUsageCache.Unlock() - - if time.Since(memoryUsageCache.lastChecked) < memoryCheckInterval { - return memoryUsageCache.lastFraction, nil - } - - fraction, err := getSystemMemoryUsedFraction() - if err != nil { - return 0, err - } - - memoryUsageCache.lastChecked = time.Now() - memoryUsageCache.lastFraction = fraction - - return fraction, nil -} diff --git a/pkg/spooledtempfile/spooled_test.go b/pkg/spooledtempfile/spooled_test.go index cd1aefb..81d8b31 100644 --- a/pkg/spooledtempfile/spooled_test.go +++ b/pkg/spooledtempfile/spooled_test.go @@ -11,12 +11,37 @@ import ( "testing" ) +// mockMemoryUsage mocks system memory to the specified fraction for the duration of the test. +// It uses t.Cleanup to automatically restore the original function and cache state. +// fraction should be between 0.0 (0% used) and 1.0 (100% used). +func mockMemoryUsage(t *testing.T, fraction float64) { + t.Helper() + + // Save original function + originalFn := getSystemMemoryUsedFraction + + // Reset cache and mock function + ResetMemoryCache() + getSystemMemoryUsedFraction = func() (float64, error) { + return fraction, nil + } + + // Auto-restore on test completion + t.Cleanup(func() { + getSystemMemoryUsedFraction = originalFn + // Ensure global cache is clean to prevent state pollution to other test packages + ResetMemoryCache() + }) +} + func generateTestDataInKB(size int) []byte { return bytes.Repeat([]byte("A"), size*1024) } // TestInMemoryBasic writes data below threshold and verifies it remains in memory. func TestInMemoryBasic(t *testing.T) { + mockMemoryUsage(t, 0.30) // Mock low memory usage (30%) + spool := NewSpooledTempFile("test", os.TempDir(), 100, false, -1) defer spool.Close() @@ -66,6 +91,8 @@ func TestInMemoryBasic(t *testing.T) { // TestThresholdCrossing writes enough data to switch from in-memory to disk. func TestThresholdCrossing(t *testing.T) { + mockMemoryUsage(t, 0.30) // Mock low memory usage (30%) + spool := NewSpooledTempFile("test", os.TempDir(), 64*1024, false, -1) defer spool.Close() @@ -114,6 +141,7 @@ func TestThresholdCrossing(t *testing.T) { } // TestForceOnDisk checks the fullOnDisk parameter. +// Note: This test doesn't mock memory because fullOnDisk=true forces disk behavior regardless. func TestForceOnDisk(t *testing.T) { spool := NewSpooledTempFile("test", os.TempDir(), 64*1024, true, -1) defer spool.Close() @@ -142,6 +170,8 @@ func TestForceOnDisk(t *testing.T) { // TestReadAtAndSeekInMemory tests seeking and ReadAt on an in-memory spool. func TestReadAtAndSeekInMemory(t *testing.T) { + mockMemoryUsage(t, 0.30) // Mock low memory usage (30%) + spool := NewSpooledTempFile("test", os.TempDir(), 64*1024, false, -1) defer spool.Close() @@ -185,6 +215,7 @@ func TestReadAtAndSeekInMemory(t *testing.T) { } // TestReadAtAndSeekOnDisk tests seeking and ReadAt on a spool that has switched to disk. +// Note: This test doesn't mock memory because it writes 65KB to intentionally cross the 64KB threshold. func TestReadAtAndSeekOnDisk(t *testing.T) { spool := NewSpooledTempFile("test", os.TempDir(), 64*1024, false, -1) defer spool.Close() @@ -255,6 +286,8 @@ func TestWriteAfterReadPanic(t *testing.T) { // TestCloseInMemory checks closing while still in-memory. 
func TestCloseInMemory(t *testing.T) { + mockMemoryUsage(t, 0.30) // Mock low memory usage (30%) + spool := NewSpooledTempFile("test", os.TempDir(), 64*1024, false, -1) _, err := spool.Write([]byte("Small data")) @@ -284,6 +317,7 @@ func TestCloseInMemory(t *testing.T) { } // TestCloseOnDisk checks closing after spool has switched to disk. +// Note: This test doesn't mock memory because it writes 65KB to intentionally cross the threshold. func TestCloseOnDisk(t *testing.T) { spool := NewSpooledTempFile("test", os.TempDir(), 64*1024, false, -1) @@ -327,6 +361,8 @@ func TestCloseOnDisk(t *testing.T) { // TestLen verifies Len() for both in-memory and on-disk states. func TestLen(t *testing.T) { + mockMemoryUsage(t, 0.30) // Mock low memory usage (30%) + spool := NewSpooledTempFile("test", os.TempDir(), 64*1024, false, -1) defer spool.Close() @@ -351,6 +387,8 @@ func TestLen(t *testing.T) { // TestFileName checks correctness of FileName in both modes. func TestFileName(t *testing.T) { + mockMemoryUsage(t, 0.30) // Mock low memory usage (30%) + spool := NewSpooledTempFile("testprefix", os.TempDir(), 64*1024, false, -1) defer spool.Close() @@ -383,17 +421,7 @@ func TestFileName(t *testing.T) { // TestSkipInMemoryAboveRAMUsage verifies that if `isSystemMemoryUsageHigh()` // returns true, the spool goes directly to disk even for small writes. func TestSkipInMemoryAboveRAMUsage(t *testing.T) { - memoryUsageCache = &globalMemoryCache{} - // Save the old function so we can restore it later - oldGetSystemMemoryUsedFraction := getSystemMemoryUsedFraction - // Force system memory usage to appear above 50% - getSystemMemoryUsedFraction = func() (float64, error) { - return 0.60, nil // 60% used => above the 50% threshold - } - // Restore after test - defer func() { - getSystemMemoryUsedFraction = oldGetSystemMemoryUsedFraction - }() + mockMemoryUsage(t, 0.60) // Mock memory usage at 60% (above 50% threshold) // Even though threshold is large (e.g. 1MB), because our mock usage is 60%, // spool should skip memory and go straight to disk. @@ -427,9 +455,67 @@ func TestSkipInMemoryAboveRAMUsage(t *testing.T) { } } +// TestMemoryThresholdBelowLimit verifies behavior when memory is just below threshold (49%). +func TestMemoryThresholdBelowLimit(t *testing.T) { + mockMemoryUsage(t, 0.49) // Mock memory at 49% (below 50% threshold) + + spool := NewSpooledTempFile("test", os.TempDir(), 1024*1024, false, 0.50) + defer spool.Close() + + data := []byte("Should stay in memory") + _, err := spool.Write(data) + if err != nil { + t.Fatalf("Write error: %v", err) + } + + // Should stay in memory since 49% < 50% + if spool.FileName() != "" { + t.Errorf("Expected spool to stay in memory (49%% < 50%%), but got file: %s", spool.FileName()) + } +} + +// TestMemoryThresholdAtLimit verifies behavior when memory is exactly at threshold (50%). +func TestMemoryThresholdAtLimit(t *testing.T) { + mockMemoryUsage(t, 0.50) // Mock memory at exactly 50% (at threshold) + + spool := NewSpooledTempFile("test", os.TempDir(), 1024*1024, false, 0.50) + defer spool.Close() + + data := []byte("Should go to disk") + _, err := spool.Write(data) + if err != nil { + t.Fatalf("Write error: %v", err) + } + + // Should go to disk since 50% >= 50% + if spool.FileName() == "" { + t.Error("Expected spool to go to disk (50%% >= 50%%), but stayed in memory") + } +} + +// TestMemoryThresholdAboveLimit verifies behavior when memory is above threshold (51%). 
+func TestMemoryThresholdAboveLimit(t *testing.T) { + mockMemoryUsage(t, 0.51) // Mock memory at 51% (above 50% threshold) + + spool := NewSpooledTempFile("test", os.TempDir(), 1024*1024, false, 0.50) + defer spool.Close() + + data := []byte("Should go to disk") + _, err := spool.Write(data) + if err != nil { + t.Fatalf("Write error: %v", err) + } + + // Should go to disk since 51% > 50% + if spool.FileName() == "" { + t.Error("Expected spool to go to disk (51%% > 50%%), but stayed in memory") + } +} + // TestBufferGrowthWithinLimits verifies that the buffer grows dynamically but never exceeds MaxInMemorySize. func TestBufferGrowthWithinLimits(t *testing.T) { - memoryUsageCache = &globalMemoryCache{} + mockMemoryUsage(t, 0.30) // Mock low memory usage (30%) + spool := NewSpooledTempFile("test", os.TempDir(), 128*1024, false, -1) defer spool.Close() @@ -473,6 +559,8 @@ func TestBufferGrowthWithinLimits(t *testing.T) { // TestPoolBehavior verifies that buffers exceeding InitialBufferSize are not returned to the pool. func TestPoolBehavior(t *testing.T) { + mockMemoryUsage(t, 0.30) // Mock low memory to ensure in-memory pooling behavior is tested + spool := NewSpooledTempFile("test", os.TempDir(), 150*1024, false, -1) defer spool.Close() @@ -503,6 +591,8 @@ func TestPoolBehavior(t *testing.T) { } } +// TestBufferGrowthBeyondNewCap verifies buffer behavior when growth exceeds threshold. +// Note: This test doesn't mock memory because it writes 101KB to intentionally exceed the 100KB threshold. func TestBufferGrowthBeyondNewCap(t *testing.T) { spool := NewSpooledTempFile("test", os.TempDir(), 100*1024, false, -1) defer spool.Close() @@ -546,6 +636,8 @@ func TestBufferGrowthBeyondNewCap(t *testing.T) { } } +// TestSpoolingWhenIOCopy verifies spooling behavior with io.Copy for large data. +// Note: This test doesn't mock memory because it writes 500KB to intentionally trigger disk spooling. func TestSpoolingWhenIOCopy(t *testing.T) { spool := NewSpooledTempFile("test", os.TempDir(), 100*1024, false, -1) defer spool.Close() diff --git a/read.go b/read.go index 6e179d0..bf107d8 100644 --- a/read.go +++ b/read.go @@ -263,6 +263,10 @@ func (r *Reader) ReadRecord(opts ...ReadOpts) (*Record, error) { } return nil, fmt.Errorf("reading WARC version: %w", err) } + // Copy the WARC version before parsing headers since readUntilDelim reuses + // its backing buffer via a sync.Pool. Without copying, subsequent calls may + // overwrite the data referenced by warcVer, resulting in corrupted versions. + warcVersion := string(warcVer) header := NewHeader() for { @@ -331,7 +335,7 @@ func (r *Reader) ReadRecord(opts ...ReadOpts) (*Record, error) { record := &Record{ Header: header, Content: buf, - Version: string(warcVer), + Version: warcVersion, Offset: offset, Size: size, } diff --git a/smoke_test.go b/smoke_test.go new file mode 100644 index 0000000..6411001 --- /dev/null +++ b/smoke_test.go @@ -0,0 +1,154 @@ +package warc + +import ( + "io" + "os" + "strconv" + "testing" +) + +// TestSmokeWARCFormatRegression validates that the WARC format remains consistent +// by checking a frozen reference file (testdata/test.warc.gz) against known-good values. +// +// This test serves as a regression detector for WARC format changes, complementing the +// dynamic tests in client_test.go. It addresses the concern that byte-level format +// changes should be explicitly validated against a known-good snapshot. +// +// If this test fails, it indicates that either: +// 1. 
The WARC writing logic has changed in a way that affects the format +// 2. The reference file has been modified +// 3. There's a bug in the record serialization +func TestSmokeWARCFormatRegression(t *testing.T) { + const testFile = "testdata/warcs/test.warc.gz" + + // Expected file-level metrics + const expectedFileSize = 22350 // bytes (compressed) + const expectedTotalRecords = 3 + const expectedTotalContentLength = 22083 // sum of all Content-Length values + + // Expected record-level metrics + // These values were extracted from a known-good WARC file and serve as + // a snapshot of correct format behavior. + expectedRecords := []struct { + warcType string + contentLength int64 + blockDigest string + payloadDigest string // only for response records + targetURI string // only for response records + }{ + { + warcType: "warcinfo", + contentLength: 143, + blockDigest: "sha1:IYWIATZSPEOF7U5W7VGGJOSQTIWUDXQ6", + }, + { + warcType: "request", + contentLength: 110, + blockDigest: "sha1:JNDMG56JVTVVOQSDQRD25XWTGMRQAQDB", + }, + { + warcType: "response", + contentLength: 21830, + blockDigest: "sha1:LCKC4TTRSBWYHGYT5P22ON4DWY65WHDZ", + targetURI: "https://apis.google.com/js/platform.js", + }, + } + + // Validate file size + stat, err := os.Stat(testFile) + if err != nil { + t.Fatalf("failed to stat test file: %v", err) + } + if stat.Size() != expectedFileSize { + t.Errorf("file size mismatch: expected %d bytes, got %d bytes", expectedFileSize, stat.Size()) + } + + // Open and read WARC file + file, err := os.Open(testFile) + if err != nil { + t.Fatalf("failed to open test file: %v", err) + } + defer file.Close() + + reader, err := NewReader(file) + if err != nil { + t.Fatalf("failed to create WARC reader: %v", err) + } + + var recordCount int + var totalContentLength int64 + + // Read and validate each record + for recordCount < expectedTotalRecords { + record, err := reader.ReadRecord() + if err != nil { + if err == io.EOF { + break + } + t.Fatalf("failed to read record %d: %v", recordCount+1, err) + } + if record == nil { + break + } + + expected := expectedRecords[recordCount] + + // Validate WARC-Type + warcType := record.Header.Get("WARC-Type") + if warcType != expected.warcType { + t.Errorf("record %d: WARC-Type mismatch: expected %q, got %q", + recordCount+1, expected.warcType, warcType) + } + + // Validate Content-Length + contentLengthStr := record.Header.Get("Content-Length") + contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64) + if err != nil { + t.Errorf("record %d: failed to parse Content-Length %q: %v", + recordCount+1, contentLengthStr, err) + } else { + if contentLength != expected.contentLength { + t.Errorf("record %d: Content-Length mismatch: expected %d, got %d", + recordCount+1, expected.contentLength, contentLength) + } + totalContentLength += contentLength + } + + // Validate WARC-Block-Digest + blockDigest := record.Header.Get("WARC-Block-Digest") + if blockDigest != expected.blockDigest { + t.Errorf("record %d: WARC-Block-Digest mismatch: expected %q, got %q", + recordCount+1, expected.blockDigest, blockDigest) + } + + // Validate response-specific fields + if warcType == "response" { + if expected.targetURI != "" { + targetURI := record.Header.Get("WARC-Target-URI") + if targetURI != expected.targetURI { + t.Errorf("record %d: WARC-Target-URI mismatch: expected %q, got %q", + recordCount+1, expected.targetURI, targetURI) + } + } + } + + // Close record content + if err := record.Content.Close(); err != nil { + t.Errorf("record %d: failed to close content: 
%v", recordCount+1, err) + } + + recordCount++ + } + + // Validate total record count + if recordCount != expectedTotalRecords { + t.Errorf("total record count mismatch: expected %d, got %d", + expectedTotalRecords, recordCount) + } + + // Validate total content length + if totalContentLength != expectedTotalContentLength { + t.Errorf("total content length mismatch: expected %d bytes, got %d bytes", + expectedTotalContentLength, totalContentLength) + } +} From 8541bb927b89d0482671981da67c5804b806ee5b Mon Sep 17 00:00:00 2001 From: Corentin Barreau Date: Thu, 13 Nov 2025 11:25:12 +0100 Subject: [PATCH 06/10] Add configurable DNS parallelization & round-robin (#156) --- README.md | 26 +++ client.go | 3 +- dialer.go | 109 +++++++++--- dns.go | 159 +++++++++++++---- dns_test.go | 482 +++++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 4 +- read_test.go | 10 ++ 7 files changed, 738 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index 03017ed..7c34de6 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,32 @@ func main() { } ``` +### DNS Resolution and Proxy Behavior + +The library handles DNS resolution differently depending on the connection type: + +#### Direct Connections (No Proxy) +- DNS is resolved locally using configured DNS servers +- DNS queries and responses are archived in WARC files as `resource` records +- Resolved IP addresses are cached with configurable TTL + +#### Local DNS Proxies (`socks5://`, `socks4://`) +- DNS is resolved locally by gowarc +- DNS records are archived to WARC files +- Resolved IP addresses are sent to the proxy +- Only one DNS query is made (no duplicate resolution) + +#### Remote DNS Proxies (`socks5h://`, `socks4a://`, `http://`, `https://`) +- **DNS archiving is skipped** to prevent privacy leaks +- Hostnames are sent directly to the proxy +- The proxy handles DNS resolution on its end +- Trade-offs: + - ✅ **Privacy**: No local DNS queries that could expose browsing activity + - ✅ **Accuracy**: WARC reflects the actual connection (no potential DNS mismatch) + - ⚠️ **No DNS WARC records**: DNS information is not archived for these connections + +**Important for Privacy**: When using `socks5h://` or other remote DNS proxies, your local DNS servers will not see any queries for the target domains, maintaining better privacy and anonymity. 
+ ## CLI Tools In addition to the Go library, gowarc provides several command-line utilities for working with WARC files: diff --git a/client.go b/client.go index da3e624..9591e55 100644 --- a/client.go +++ b/client.go @@ -25,6 +25,7 @@ type HTTPClientSettings struct { DNSResolutionTimeout time.Duration DNSRecordsTTL time.Duration DNSCacheSize int + DNSConcurrency int TLSHandshakeTimeout time.Duration ConnReadDeadline time.Duration MaxReadBeforeTruncate int @@ -215,7 +216,7 @@ func NewWARCWritingHTTPClient(HTTPClientSettings HTTPClientSettings) (httpClient httpClient.ConnReadDeadline = HTTPClientSettings.ConnReadDeadline // Configure custom dialer / transport - customDialer, err := newCustomDialer(httpClient, HTTPClientSettings.Proxy, HTTPClientSettings.DialTimeout, HTTPClientSettings.DNSRecordsTTL, HTTPClientSettings.DNSResolutionTimeout, HTTPClientSettings.DNSCacheSize, HTTPClientSettings.DNSServers, HTTPClientSettings.DisableIPv4, HTTPClientSettings.DisableIPv6) + customDialer, err := newCustomDialer(httpClient, HTTPClientSettings.Proxy, HTTPClientSettings.DialTimeout, HTTPClientSettings.DNSRecordsTTL, HTTPClientSettings.DNSResolutionTimeout, HTTPClientSettings.DNSCacheSize, HTTPClientSettings.DNSServers, HTTPClientSettings.DNSConcurrency, HTTPClientSettings.DisableIPv4, HTTPClientSettings.DisableIPv6) if err != nil { return nil, err } diff --git a/dialer.go b/dialer.go index c782c3c..c642cc5 100644 --- a/dialer.go +++ b/dialer.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -61,15 +62,23 @@ func WithWrappedConnection(ctx context.Context, wrappedConnChan chan *CustomConn return context.WithValue(ctx, ContextKeyWrappedConn, wrappedConnChan) } +// dnsExchanger is an interface for DNS clients that can exchange messages +type dnsExchanger interface { + ExchangeContext(ctx context.Context, m *dns.Msg, address string) (r *dns.Msg, rtt time.Duration, err error) +} + type customDialer struct { - proxyDialer proxy.ContextDialer - client *CustomHTTPClient - DNSConfig *dns.ClientConfig - DNSClient *dns.Client - DNSRecords *otter.Cache[string, net.IP] + proxyDialer proxy.ContextDialer + proxyNeedsHostname bool // true if proxy requires hostname (socks5h, http), false if can use IP (socks5) + client *CustomHTTPClient + DNSConfig *dns.ClientConfig + DNSClient dnsExchanger + DNSRecords *otter.Cache[string, net.IP] net.Dialer - disableIPv4 bool - disableIPv6 bool + disableIPv4 bool + disableIPv6 bool + dnsConcurrency int + dnsRoundRobinIndex atomic.Uint32 } var emptyPayloadDigests = []string{ @@ -79,13 +88,14 @@ var emptyPayloadDigests = []string{ "blake3:af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262", } -func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout, DNSRecordsTTL, DNSResolutionTimeout time.Duration, DNSCacheSize int, DNSServers []string, disableIPv4, disableIPv6 bool) (d *customDialer, err error) { +func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout, DNSRecordsTTL, DNSResolutionTimeout time.Duration, DNSCacheSize int, DNSServers []string, DNSConcurrency int, disableIPv4, disableIPv6 bool) (d *customDialer, err error) { d = new(customDialer) d.Timeout = DialTimeout d.client = httpClient d.disableIPv4 = disableIPv4 d.disableIPv6 = disableIPv6 + d.dnsConcurrency = DNSConcurrency DNScache, err := otter.MustBuilder[string, net.IP](DNSCacheSize). // CollectStats(). 
// Uncomment this line to enable stats collection, can be useful later on @@ -123,6 +133,12 @@ func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout, } d.proxyDialer = proxyDialer.(proxy.ContextDialer) + + // Determine if this proxy requires hostname (remote DNS) or can use IP (local DNS) + // Proxies with remote DNS: socks5h, socks4a, http, https + // Proxies with local DNS: socks5, socks4 + d.proxyNeedsHostname = u.Scheme == "socks5h" || u.Scheme == "socks4a" || + u.Scheme == "http" || u.Scheme == "https" } return d, nil @@ -219,26 +235,50 @@ func (d *customDialer) CustomDialContext(ctx context.Context, network, address s return nil, errors.New("no supported network type available") } - IP, _, err := d.archiveDNS(ctx, address) - if err != nil { - return nil, err + var dialAddr string + var IP net.IP + + if d.proxyDialer != nil && d.proxyNeedsHostname { + // Remote DNS proxy (socks5h, socks4a, http, https) + // Skip DNS archiving to avoid privacy leak and ensure accuracy. + // The proxy will handle DNS resolution on its end, and we don't want to: + // 1. Leak DNS queries to local DNS servers (defeats purpose of socks5h) + // 2. Archive potentially incorrect DNS results (local DNS may differ from proxy's DNS) + dialAddr = address + } else { + // Direct connection or local DNS proxy (socks5, socks4) + // Archive DNS and use resolved IP + IP, _, err = d.archiveDNS(ctx, address) + if err != nil { + return nil, err + } + + // Extract port from address for IP:port construction + var port string + _, port, err = net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("failed to extract port from address %s: %w", address, err) + } + + dialAddr = net.JoinHostPort(IP.String(), port) } if d.proxyDialer != nil { - conn, err = d.proxyDialer.DialContext(ctx, network, address) + conn, err = d.proxyDialer.DialContext(ctx, network, dialAddr) } else { if d.client.randomLocalIP { localAddr := getLocalAddr(network, IP) if localAddr != nil { - if network == "tcp" || network == "tcp4" || network == "tcp6" { + switch network { + case "tcp", "tcp4", "tcp6": d.LocalAddr = localAddr.(*net.TCPAddr) - } else if network == "udp" || network == "udp4" || network == "udp6" { + case "udp", "udp4", "udp6": d.LocalAddr = localAddr.(*net.UDPAddr) } } } - conn, err = d.DialContext(ctx, network, address) + conn, err = d.DialContext(ctx, network, dialAddr) } if err != nil { @@ -259,28 +299,53 @@ func (d *customDialer) CustomDialTLSContext(ctx context.Context, network, addres return nil, errors.New("no supported network type available") } - IP, _, err := d.archiveDNS(ctx, address) - if err != nil { - return nil, err + var dialAddr string + var IP net.IP + var err error + + if d.proxyDialer != nil && d.proxyNeedsHostname { + // Remote DNS proxy (socks5h, socks4a, http, https) + // Skip DNS archiving to avoid privacy leak and ensure accuracy. + // The proxy will handle DNS resolution on its end, and we don't want to: + // 1. Leak DNS queries to local DNS servers (defeats purpose of socks5h) + // 2. 
Archive potentially incorrect DNS results (local DNS may differ from proxy's DNS) + dialAddr = address + } else { + // Direct connection or local DNS proxy (socks5, socks4) + // Archive DNS and use resolved IP + IP, _, err = d.archiveDNS(ctx, address) + if err != nil { + return nil, err + } + + // Extract port from address for IP:port construction + var port string + _, port, err = net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("failed to extract port from address %s: %w", address, err) + } + + dialAddr = net.JoinHostPort(IP.String(), port) } var plainConn net.Conn if d.proxyDialer != nil { - plainConn, err = d.proxyDialer.DialContext(ctx, network, address) + plainConn, err = d.proxyDialer.DialContext(ctx, network, dialAddr) } else { if d.client.randomLocalIP { localAddr := getLocalAddr(network, IP) if localAddr != nil { - if network == "tcp" || network == "tcp4" || network == "tcp6" { + switch network { + case "tcp", "tcp4", "tcp6": d.LocalAddr = localAddr.(*net.TCPAddr) - } else if network == "udp" || network == "udp4" || network == "udp6" { + case "udp", "udp4", "udp6": d.LocalAddr = localAddr.(*net.UDPAddr) } } } - plainConn, err = d.DialContext(ctx, network, address) + plainConn, err = d.DialContext(ctx, network, dialAddr) } if err != nil { diff --git a/dns.go b/dns.go index 07b963a..4d3a51e 100644 --- a/dns.go +++ b/dns.go @@ -9,8 +9,6 @@ import ( "github.com/miekg/dns" ) -const maxFallbackDNSServers = 3 - func (d *customDialer) archiveDNS(ctx context.Context, address string) (resolvedIP net.IP, cached bool, err error) { // Get the address without the port if there is one address, _, err = net.SplitHostPort(address) @@ -29,39 +27,14 @@ func (d *customDialer) archiveDNS(ctx context.Context, address string) (resolved return cachedIP, true, nil } - var wg sync.WaitGroup - var ipv4, ipv6 net.IP - var errA, errAAAA error - if len(d.DNSConfig.Servers) == 0 { return nil, false, fmt.Errorf("no DNS servers configured") } - fallbackServers := min(maxFallbackDNSServers, len(d.DNSConfig.Servers)-1) - - for DNSServer := 0; DNSServer <= fallbackServers; DNSServer++ { - if !d.disableIPv4 { - wg.Add(1) - go func() { - defer wg.Done() - ipv4, errA = d.lookupIP(ctx, address, dns.TypeA, DNSServer) - }() - } - - if !d.disableIPv6 { - wg.Add(1) - go func() { - defer wg.Done() - ipv6, errAAAA = d.lookupIP(ctx, address, dns.TypeAAAA, DNSServer) - }() - } - - wg.Wait() + var ipv4, ipv6 net.IP + var errA, errAAAA error - if errA == nil || errAAAA == nil { - break - } - } + ipv4, ipv6, errA, errAAAA = d.concurrentDNSLookup(ctx, address, len(d.DNSConfig.Servers)) if errA != nil && errAAAA != nil { return nil, false, fmt.Errorf("failed to resolve DNS: A error: %v, AAAA error: %v", errA, errAAAA) } @@ -82,6 +55,132 @@ func (d *customDialer) archiveDNS(ctx context.Context, address string) (resolved return nil, false, fmt.Errorf("no suitable IP address found for %s", address) } +// concurrentDNSLookup tries DNS servers with configurable concurrency +// - dnsConcurrency <= 1: sequential (one server at a time) +// - dnsConcurrency > 1: that many servers concurrently +// - dnsConcurrency == -1: all servers at once (unlimited) +// Implements early cancellation: stops querying once results are found +func (d *customDialer) concurrentDNSLookup(ctx context.Context, address string, maxServers int) (ipv4, ipv6 net.IP, errA, errAAAA error) { + type result struct { + ip net.IP + err error + recordType uint16 + } + + // Determine effective concurrency + concurrency := d.dnsConcurrency + if concurrency 
== -1 { + concurrency = maxServers // Unlimited = all servers + } else if concurrency <= 0 { + concurrency = 1 // Default to sequential + } + + // Create cancellable context for early termination + workerCtx, cancel := context.WithCancel(ctx) + defer cancel() + + resultChan := make(chan result, maxServers*2) + serverChan := make(chan int, maxServers) + var wg sync.WaitGroup + + // Fill server queue with round-robin starting index + // Atomically increment and get the starting position + startIdx := int(d.dnsRoundRobinIndex.Add(1)-1) % maxServers + for i := range maxServers { + serverIdx := (startIdx + i) % maxServers + serverChan <- serverIdx + } + close(serverChan) + + // Helper to check if we have all needed results + haveAllResults := func() bool { + if !d.disableIPv4 && ipv4 == nil { + return false + } + if !d.disableIPv6 && ipv6 == nil { + return false + } + return true + } + + // Launch worker goroutines (limited by concurrency) + for i := 0; i < concurrency && i < maxServers; i++ { + wg.Go(func() { + for serverIdx := range serverChan { + // Check if context was cancelled before starting queries + select { + case <-workerCtx.Done(): + return + default: + } + + // Query both A and AAAA for this server + if !d.disableIPv4 { + ip, err := d.lookupIP(workerCtx, address, dns.TypeA, serverIdx) + select { + case resultChan <- result{ip: ip, err: err, recordType: dns.TypeA}: + case <-workerCtx.Done(): + return + } + } + if !d.disableIPv6 { + ip, err := d.lookupIP(workerCtx, address, dns.TypeAAAA, serverIdx) + select { + case resultChan <- result{ip: ip, err: err, recordType: dns.TypeAAAA}: + case <-workerCtx.Done(): + return + } + } + } + }) + } + + // Close result channel when all workers complete + go func() { + wg.Wait() + close(resultChan) + }() + + // Collect results with early termination + var ipv4Errors, ipv6Errors []error + for res := range resultChan { + if res.err == nil { + if res.recordType == dns.TypeA && ipv4 == nil { + ipv4 = res.ip + } else if res.recordType == dns.TypeAAAA && ipv6 == nil { + ipv6 = res.ip + } + + // Early termination: if we have all results, cancel workers + if haveAllResults() { + cancel() + // Drain remaining results to prevent worker blocking + go func() { + for range resultChan { + } + }() + break + } + } else { + if res.recordType == dns.TypeA { + ipv4Errors = append(ipv4Errors, res.err) + } else { + ipv6Errors = append(ipv6Errors, res.err) + } + } + } + + // Set errors only if all queries of that type failed + if ipv4 == nil && len(ipv4Errors) > 0 { + errA = ipv4Errors[0] + } + if ipv6 == nil && len(ipv6Errors) > 0 { + errAAAA = ipv6Errors[0] + } + + return ipv4, ipv6, errA, errAAAA +} + func (d *customDialer) lookupIP(ctx context.Context, address string, recordType uint16, DNSServer int) (net.IP, error) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(address), recordType) diff --git a/dns_test.go b/dns_test.go index 6a94203..cd15015 100644 --- a/dns_test.go +++ b/dns_test.go @@ -5,6 +5,7 @@ import ( "errors" "net" "os" + "sync" "testing" "time" @@ -22,6 +23,107 @@ const ( target1 = "www.archive.org:443" ) +// mockDNSClient is a mock DNS client for testing that doesn't make real network calls +type mockDNSClient struct { + // responses maps server address to response config + responses map[string]mockDNSResponse + // callLog tracks which servers were called and in what order + callLog []string + callLogMu sync.Mutex + // delay simulates network latency per server + delays map[string]time.Duration +} + +type mockDNSResponse struct { + ipv4 net.IP + ipv6 
net.IP + err error +} + +func newMockDNSClient() *mockDNSClient { + return &mockDNSClient{ + responses: make(map[string]mockDNSResponse), + callLog: []string{}, + delays: make(map[string]time.Duration), + } +} + +func (m *mockDNSClient) setResponse(serverAddr string, ipv4, ipv6 net.IP, err error) { + m.responses[serverAddr] = mockDNSResponse{ipv4: ipv4, ipv6: ipv6, err: err} +} + +func (m *mockDNSClient) setDelay(serverAddr string, delay time.Duration) { + m.delays[serverAddr] = delay +} + +func (m *mockDNSClient) getCallLog() []string { + m.callLogMu.Lock() + defer m.callLogMu.Unlock() + result := make([]string, len(m.callLog)) + copy(result, m.callLog) + return result +} + +func (m *mockDNSClient) ExchangeContext(ctx context.Context, msg *dns.Msg, address string) (*dns.Msg, time.Duration, error) { + // Log the call + m.callLogMu.Lock() + m.callLog = append(m.callLog, address) + m.callLogMu.Unlock() + + // Simulate delay if configured + if delay, ok := m.delays[address]; ok && delay > 0 { + select { + case <-time.After(delay): + case <-ctx.Done(): + return nil, 0, ctx.Err() + } + } + + // Get configured response + resp, ok := m.responses[address] + if !ok { + return nil, 0, errors.New("no mock response configured for " + address) + } + + if resp.err != nil { + return nil, 0, resp.err + } + + // Build DNS response message + r := new(dns.Msg) + r.SetReply(msg) + + // Determine record type being queried + if len(msg.Question) > 0 { + qtype := msg.Question[0].Qtype + if qtype == dns.TypeA && resp.ipv4 != nil { + rr := &dns.A{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: resp.ipv4, + } + r.Answer = append(r.Answer, rr) + } else if qtype == dns.TypeAAAA && resp.ipv6 != nil { + rr := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 300, + }, + AAAA: resp.ipv6, + } + r.Answer = append(r.Answer, rr) + } + } + + return r, 0, nil +} + func newTestCustomDialer() (d *customDialer) { d = new(customDialer) @@ -44,6 +146,60 @@ func newTestCustomDialer() (d *customDialer) { return d } +// setupMock creates a dialer for mock-based tests without WARC writing +func setupMock() (*customDialer, func()) { + d := newTestCustomDialer() + + // Use a no-op client with a WARC writer channel that gets drained + d.client = &CustomHTTPClient{ + WARCWriter: make(chan *RecordBatch, 100), // Large buffer to prevent blocking + } + + // Drain the WARC writer channel to prevent blocking + stopDrain := make(chan bool) + var drainerWg sync.WaitGroup + drainerWg.Go(func() { + for { + select { + case batch := <-d.client.WARCWriter: + // Send feedback immediately to unblock writer + if batch != nil && batch.FeedbackChan != nil { + batch.FeedbackChan <- struct{}{} + } + case <-stopDrain: + // Drain remaining items + for { + select { + case batch := <-d.client.WARCWriter: + if batch != nil && batch.FeedbackChan != nil { + batch.FeedbackChan <- struct{}{} + } + default: + return + } + } + } + } + }) + + cleanup := func() { + // Wait for DNS operations to complete + time.Sleep(200 * time.Millisecond) + // Stop the drainer + close(stopDrain) + // Wait for drainer to finish + drainerWg.Wait() + // Close the cache + d.DNSRecords.Close() + // Now safe to close the channel + close(d.client.WARCWriter) + // Give otter cache goroutines time to shut down (matches original setup()) + time.Sleep(1 * time.Second) + } + + return d, cleanup +} + func setup(t *testing.T) (*customDialer, *CustomHTTPClient, 
func()) { var ( rotatorSettings = NewRotatorSettings() @@ -188,3 +344,329 @@ func TestDNSCaching(t *testing.T) { t.Error("Expected cached result") } } + +// TestDNSConcurrencySequential tests sequential DNS resolution (concurrency=1) +func TestDNSConcurrencySequential(t *testing.T) { + d, cleanup := setupMock() + defer cleanup() + + mock := newMockDNSClient() + d.DNSClient = mock + d.dnsConcurrency = 1 + + // Configure 3 servers where first fails, second succeeds + server1 := "1.1.1.1:53" + server2 := "2.2.2.2:53" + server3 := "3.3.3.3:53" + + mock.setResponse(server1, nil, nil, errors.New("server 1 failed")) + mock.setResponse(server2, net.ParseIP("192.0.2.1"), net.ParseIP("2001:db8::1"), nil) + mock.setResponse(server3, net.ParseIP("192.0.2.3"), net.ParseIP("2001:db8::3"), nil) + + d.DNSConfig.Servers = []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"} + d.DNSRecords.Delete(targetHost) + + IP, _, err := d.archiveDNS(context.Background(), target) + if err != nil { + t.Fatal(err) + } + + if IP == nil { + t.Fatal("Expected resolved IP") + } + + // With sequential mode, should have tried servers one at a time + callLog := mock.getCallLog() + if len(callLog) < 2 { + t.Errorf("Expected at least 2 calls (first fails, second succeeds), got %d", len(callLog)) + } + t.Logf("Call log: %v", callLog) +} + +// TestDNSConcurrencyParallel tests parallel DNS resolution with limited concurrency +func TestDNSConcurrencyParallel(t *testing.T) { + d, cleanup := setupMock() + defer cleanup() + + mock := newMockDNSClient() + d.DNSClient = mock + d.dnsConcurrency = 3 + + // Configure 3 servers, all succeed + server1 := "1.1.1.1:53" + server2 := "2.2.2.2:53" + server3 := "3.3.3.3:53" + + mock.setResponse(server1, net.ParseIP("192.0.2.1"), net.ParseIP("2001:db8::1"), nil) + mock.setResponse(server2, net.ParseIP("192.0.2.2"), net.ParseIP("2001:db8::2"), nil) + mock.setResponse(server3, net.ParseIP("192.0.2.3"), net.ParseIP("2001:db8::3"), nil) + + d.DNSConfig.Servers = []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"} + d.DNSRecords.Delete(targetHost) + + IP, _, err := d.archiveDNS(context.Background(), target) + if err != nil { + t.Fatal(err) + } + + if IP == nil { + t.Fatal("Expected resolved IP") + } + + // With parallel mode (concurrency=3), all 3 servers may be queried + callLog := mock.getCallLog() + t.Logf("Call log: %v (length: %d)", callLog, len(callLog)) +} + +// TestDNSConcurrencyUnlimited tests unlimited parallel DNS resolution +func TestDNSConcurrencyUnlimited(t *testing.T) { + d, cleanup := setupMock() + defer cleanup() + + mock := newMockDNSClient() + d.DNSClient = mock + d.dnsConcurrency = -1 // Unlimited + + // Configure 3 servers + server1 := "1.1.1.1:53" + server2 := "2.2.2.2:53" + server3 := "3.3.3.3:53" + + mock.setResponse(server1, net.ParseIP("192.0.2.1"), net.ParseIP("2001:db8::1"), nil) + mock.setResponse(server2, net.ParseIP("192.0.2.2"), net.ParseIP("2001:db8::2"), nil) + mock.setResponse(server3, net.ParseIP("192.0.2.3"), net.ParseIP("2001:db8::3"), nil) + + d.DNSConfig.Servers = []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"} + d.DNSRecords.Delete(targetHost) + + IP, _, err := d.archiveDNS(context.Background(), target) + if err != nil { + t.Fatal(err) + } + + if IP == nil { + t.Fatal("Expected resolved IP") + } + + // With unlimited concurrency, servers are queried in parallel + callLog := mock.getCallLog() + t.Logf("Call log: %v (length: %d)", callLog, len(callLog)) +} + +// TestDNSRoundRobin verifies round-robin DNS server selection +func TestDNSRoundRobin(t *testing.T) { + d, cleanup := setupMock() + 
defer cleanup() + + mock := newMockDNSClient() + d.DNSClient = mock + d.dnsConcurrency = 1 // Sequential for predictable ordering + + // Configure 3 servers + server1 := "1.1.1.1:53" + server2 := "2.2.2.2:53" + server3 := "3.3.3.3:53" + + mock.setResponse(server1, net.ParseIP("192.0.2.1"), net.ParseIP("2001:db8::1"), nil) + mock.setResponse(server2, net.ParseIP("192.0.2.2"), net.ParseIP("2001:db8::2"), nil) + mock.setResponse(server3, net.ParseIP("192.0.2.3"), net.ParseIP("2001:db8::3"), nil) + + d.DNSConfig.Servers = []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"} + + // Track starting servers for each lookup + startingServers := []string{} + + // Perform 6 lookups (2 full rotations) + for i := 0; i < 6; i++ { + host := "example" + string(rune('a'+i)) + ".com:443" + d.DNSRecords.Delete("example" + string(rune('a'+i)) + ".com") + + callLogBefore := len(mock.getCallLog()) + _, _, err := d.archiveDNS(context.Background(), host) + if err != nil { + t.Fatal(err) + } + + callLog := mock.getCallLog() + if len(callLog) > callLogBefore { + // First call in this lookup shows the starting server + startingServers = append(startingServers, callLog[callLogBefore]) + } + } + + t.Logf("Starting servers for each lookup: %v", startingServers) + + // Verify round-robin: each lookup should start from a different server + // We should see rotation across the 3 servers + if len(startingServers) < 6 { + t.Fatalf("Expected 6 starting servers, got %d", len(startingServers)) + } + + // Check that not all lookups started from the same server + uniqueStarts := make(map[string]bool) + for _, server := range startingServers { + uniqueStarts[server] = true + } + + if len(uniqueStarts) == 1 { + t.Error("All lookups started from same server - round-robin not working") + } else { + t.Logf("Round-robin working: saw %d different starting servers", len(uniqueStarts)) + } +} + +// TestDNSMultipleServersFallback tests that all servers are tried (no 4-server limit) +func TestDNSMultipleServersFallback(t *testing.T) { + d, cleanup := setupMock() + defer cleanup() + + mock := newMockDNSClient() + d.DNSClient = mock + d.dnsConcurrency = 1 + + // Configure 3 servers where first 2 fail, third succeeds + server1 := "1.1.1.1:53" + server2 := "2.2.2.2:53" + server3 := "3.3.3.3:53" + + mock.setResponse(server1, nil, nil, errors.New("server 1 timeout")) + mock.setResponse(server2, nil, nil, errors.New("server 2 timeout")) + mock.setResponse(server3, net.ParseIP("192.0.2.3"), net.ParseIP("2001:db8::3"), nil) + + d.DNSConfig.Servers = []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"} + d.DNSRecords.Delete(targetHost) + + IP, _, err := d.archiveDNS(context.Background(), target) + if err != nil { + t.Fatal(err) + } + + if IP == nil { + t.Fatal("Expected resolved IP from third server") + } + + callLog := mock.getCallLog() + t.Logf("Call log: %v", callLog) + + // Should have tried all 3 servers (proving no hardcoded 4-server limit) + uniqueServers := make(map[string]bool) + for _, call := range callLog { + uniqueServers[call] = true + } + + if len(uniqueServers) < 3 { + t.Errorf("Expected queries to all 3 servers, got %d unique servers", len(uniqueServers)) + } +} + +// TestDNSEarlyCancellation tests that queries stop once results are found +func TestDNSEarlyCancellation(t *testing.T) { + d, cleanup := setupMock() + defer cleanup() + + mock := newMockDNSClient() + d.DNSClient = mock + d.dnsConcurrency = 3 // Allow parallel queries + + // Configure 3 servers with delays + server1 := "1.1.1.1:53" + server2 := "2.2.2.2:53" + server3 := "3.3.3.3:53" + + // 
Server 1 is fast and succeeds + mock.setResponse(server1, net.ParseIP("192.0.2.1"), net.ParseIP("2001:db8::1"), nil) + mock.setDelay(server1, 10*time.Millisecond) + + // Servers 2 and 3 are slow + mock.setResponse(server2, net.ParseIP("192.0.2.2"), net.ParseIP("2001:db8::2"), nil) + mock.setDelay(server2, 500*time.Millisecond) + + mock.setResponse(server3, net.ParseIP("192.0.2.3"), net.ParseIP("2001:db8::3"), nil) + mock.setDelay(server3, 500*time.Millisecond) + + d.DNSConfig.Servers = []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"} + d.DNSRecords.Delete(targetHost) + + start := time.Now() + IP, _, err := d.archiveDNS(context.Background(), target) + elapsed := time.Since(start) + + if err != nil { + t.Fatal(err) + } + + if IP == nil { + t.Fatal("Expected resolved IP") + } + + // Should complete quickly due to early cancellation (not wait for slow servers) + if elapsed > 200*time.Millisecond { + t.Logf("Warning: took %v, expected <200ms with early cancellation", elapsed) + } else { + t.Logf("Early cancellation working: completed in %v", elapsed) + } + + callLog := mock.getCallLog() + t.Logf("Call log: %v", callLog) +} + +// TestDNSIPv4Only tests IPv6-disabled mode +func TestDNSIPv4Only(t *testing.T) { + d, cleanup := setupMock() + defer cleanup() + + mock := newMockDNSClient() + d.DNSClient = mock + d.disableIPv4 = false + d.disableIPv6 = true + + server1 := "1.1.1.1:53" + mock.setResponse(server1, net.ParseIP("192.0.2.1"), net.ParseIP("2001:db8::1"), nil) + + d.DNSConfig.Servers = []string{"1.1.1.1"} + d.DNSRecords.Delete(targetHost) + + IP, _, err := d.archiveDNS(context.Background(), target) + if err != nil { + t.Fatal(err) + } + + // Should return IPv4 only + if IP == nil { + t.Fatal("Expected IPv4 address") + } + + if IP.To4() == nil { + t.Errorf("Expected IPv4 address, got %v", IP) + } + + t.Logf("Resolved IPv4: %v", IP) +} + +// TestDNSMixedResults tests when IPv4 succeeds but IPv6 fails +func TestDNSMixedResults(t *testing.T) { + d, cleanup := setupMock() + defer cleanup() + + mock := newMockDNSClient() + d.DNSClient = mock + + server1 := "1.1.1.1:53" + // IPv4 succeeds, IPv6 returns no record + mock.setResponse(server1, net.ParseIP("192.0.2.1"), nil, nil) + + d.DNSConfig.Servers = []string{"1.1.1.1"} + d.DNSRecords.Delete(targetHost) + + IP, _, err := d.archiveDNS(context.Background(), target) + if err != nil { + t.Fatal(err) + } + + // Should succeed with IPv4 even though IPv6 failed + if IP == nil { + t.Fatal("Expected IPv4 address") + } + + t.Logf("Resolved with mixed results: %v", IP) +} diff --git a/go.mod b/go.mod index 6c45986..a7f9c29 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/internetarchive/gowarc -go 1.24.2 +go 1.25.4 require ( github.com/google/uuid v1.6.0 @@ -17,6 +17,7 @@ require ( go.uber.org/goleak v1.3.0 golang.org/x/net v0.46.0 golang.org/x/sync v0.17.0 + golang.org/x/sys v0.37.0 ) // By default, and historically, this project uses klauspost's gzip implementation, @@ -34,7 +35,6 @@ require ( github.com/spf13/pflag v1.0.9 // indirect golang.org/x/crypto v0.43.0 // indirect golang.org/x/mod v0.24.0 // indirect - golang.org/x/sys v0.37.0 // indirect golang.org/x/tools v0.33.0 // indirect ) diff --git a/read_test.go b/read_test.go index 15a86b0..0902559 100644 --- a/read_test.go +++ b/read_test.go @@ -291,28 +291,37 @@ func testFileEarlyEOF(t *testing.T, path string) { if err != nil { t.Fatalf("failed to open %q: %v", path, err) } + reader, err := NewReader(file) if err != nil { t.Fatalf("warc.NewReader failed for %q: %v", path, err) } + reader.cr = 
&countingReader{r: reader.src} // read the file into memory reader.dec, reader.compType, err = reader.cr.newDecompressionReader() + if err != nil { + t.Fatalf("failed to create decompression reader: %v", err) + } + data, err := io.ReadAll(reader.dec) if err != nil { t.Fatalf("failed to read %q: %v", path, err) } + // delete the last two bytes (\r\n) if data[len(data)-2] != '\r' || data[len(data)-1] != '\n' { t.Fatalf("expected \\r\\n, got %q", data[len(data)-2:]) } + data = data[:len(data)-2] // new reader reader, err = NewReader(io.NopCloser(bytes.NewReader(data))) if err != nil { t.Fatalf("warc.NewReader failed for %q: %v", path, err) } + // read the records for { record, err := reader.ReadRecord() @@ -329,6 +338,7 @@ func testFileEarlyEOF(t *testing.T, path string) { } record.Content.Close() } + t.Fatalf("expected `reading record boundary: EOF`, got none") } From abf8be28029cbe06dc7b998de2ef5ddc99c70878 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 13 Nov 2025 05:45:51 -0500 Subject: [PATCH 07/10] build(deps): bump the go-modules group across 1 directory with 3 updates (#158) Bumps the go-modules group with 2 updates in the / directory: [golang.org/x/net](https://github.com/golang/net) and [golang.org/x/sync](https://github.com/golang/sync). Updates `golang.org/x/net` from 0.46.0 to 0.47.0 - [Commits](https://github.com/golang/net/compare/v0.46.0...v0.47.0) Updates `golang.org/x/sync` from 0.17.0 to 0.18.0 - [Commits](https://github.com/golang/sync/compare/v0.17.0...v0.18.0) Updates `golang.org/x/sys` from 0.37.0 to 0.38.0 - [Commits](https://github.com/golang/sys/compare/v0.37.0...v0.38.0) --- updated-dependencies: - dependency-name: golang.org/x/net dependency-version: 0.47.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: go-modules - dependency-name: golang.org/x/sync dependency-version: 0.18.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: go-modules - dependency-name: golang.org/x/sys dependency-version: 0.38.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: go-modules ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index a7f9c29..602b57a 100644 --- a/go.mod +++ b/go.mod @@ -15,9 +15,9 @@ require ( github.com/valyala/bytebufferpool v1.0.0 github.com/zeebo/blake3 v0.2.4 go.uber.org/goleak v1.3.0 - golang.org/x/net v0.46.0 - golang.org/x/sync v0.17.0 - golang.org/x/sys v0.37.0 + golang.org/x/net v0.47.0 + golang.org/x/sync v0.18.0 + golang.org/x/sys v0.38.0 ) // By default, and historically, this project uses klauspost's gzip implementation, @@ -33,7 +33,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/cpuid/v2 v2.0.12 // indirect github.com/spf13/pflag v1.0.9 // indirect - golang.org/x/crypto v0.43.0 // indirect + golang.org/x/crypto v0.44.0 // indirect golang.org/x/mod v0.24.0 // indirect golang.org/x/tools v0.33.0 // indirect ) diff --git a/go.sum b/go.sum index d481126..3c75342 100644 --- a/go.sum +++ b/go.sum @@ -50,16 +50,16 @@ github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo= github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= -golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= +golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU= +golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= -golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= -golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= -golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 70272b67e7c815d16a3142bc0b6584e974ce2332 Mon Sep 17 00:00:00 2001 From: Jake L Date: Fri, 14 Nov 2025 21:42:24 -0800 Subject: [PATCH 08/10] Add support for memory detection on Windows (#159) * feat: add windows memory I have tested locally and I'm unaware of a way to do this without `unsafe` as the native Go libraries do not support it. 
* fix: remove redundant tests --- .github/workflows/go.yml | 29 ++++------ pkg/spooledtempfile/memory_windows.go | 66 ++++++++++++++++++++++ pkg/spooledtempfile/memory_windows_test.go | 66 ++++++++++++++++++++++ 3 files changed, 143 insertions(+), 18 deletions(-) create mode 100644 pkg/spooledtempfile/memory_windows.go create mode 100644 pkg/spooledtempfile/memory_windows_test.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index f50d34b..b184e16 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -15,7 +15,7 @@ jobs: timeout-minutes: 15 strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-latest, windows-latest] steps: - uses: actions/checkout@v5 @@ -32,37 +32,30 @@ jobs: continue-on-error: true run: go test -c -o tests && for test in $(go test -list . | grep -E "^(Test|Example)"); do ./tests -test.run "^$test\$" &>/dev/null && echo -e "$test passed\n" || echo -e "$test failed\n"; done - - name: Test (Full Suite) + - name: Test (Full Suite including spooledtempfile) if: matrix.os == 'ubuntu-latest' run: go test -race -v ./... - - name: Test (spooledtempfile only) + - name: Test (spooledtempfile only for macos) if: matrix.os == 'macos-latest' run: go test -race -v ./pkg/spooledtempfile/... + + - name: Test (spooledtempfile only for windows) + if: matrix.os == 'windows-latest' + run: go test -race -v ./pkg/spooledtempfile/... - name: Benchmarks if: matrix.os == 'ubuntu-latest' run: go test -bench=. -benchmem -run=^$ ./... - # Platform-specific test verification - - name: Test Linux-specific memory implementation - if: matrix.os == 'ubuntu-latest' - run: | - echo "Running Linux-specific memory tests..." - cd pkg/spooledtempfile - go test -v -run "TestCgroup|TestHostMeminfo|TestRead" - - - name: Test macOS-specific memory implementation - if: matrix.os == 'macos-latest' - run: | - echo "Running macOS-specific memory tests..." - cd pkg/spooledtempfile - go test -v -run "TestGetSystemMemoryUsedFraction|TestSysctlMemoryValues|TestMemoryFractionConsistency" - # Cross-compilation verification - name: Cross-compile for macOS (from Linux) if: matrix.os == 'ubuntu-latest' run: GOOS=darwin GOARCH=amd64 go build ./... + + - name: Cross-compile for windows (from Linux) + if: matrix.os == 'ubuntu-latest' + run: GOOS=windows GOARCH=amd64 go build ./... - name: Cross-compile for Linux (from macOS) if: matrix.os == 'macos-latest' diff --git a/pkg/spooledtempfile/memory_windows.go b/pkg/spooledtempfile/memory_windows.go new file mode 100644 index 0000000..ffaf7a7 --- /dev/null +++ b/pkg/spooledtempfile/memory_windows.go @@ -0,0 +1,66 @@ +//go:build windows + +package spooledtempfile + +import ( + "fmt" + "unsafe" + + "golang.org/x/sys/windows" +) + +// globalMemoryStatusEx calls the Windows API function to retrieve memory status. +// This is not currently implemented by the Golang native Windows libraries. 
+func globalMemoryStatusEx() (totalPhys, availPhys uint64, err error) { + kernel32 := windows.NewLazySystemDLL("kernel32.dll") + proc := kernel32.NewProc("GlobalMemoryStatusEx") + + // Define the MEMORYSTATUSEX structure matching the Windows API + // See: https://docs.microsoft.com/en-us/windows/win32/api/sysinfoapi/ns-sysinfoapi-memorystatusex + type memoryStatusEx struct { + dwLength uint32 + dwMemoryLoad uint32 + ullTotalPhys uint64 + ullAvailPhys uint64 + ullTotalPageFile uint64 + ullAvailPageFile uint64 + ullTotalVirtual uint64 + ullAvailVirtual uint64 + ullAvailExtendedVirtual uint64 + } + + var memStatus memoryStatusEx + memStatus.dwLength = 64 + + ret, _, err := proc.Call(uintptr(unsafe.Pointer(&memStatus))) + if ret == 0 { + return 0, 0, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err) + } + + return memStatus.ullTotalPhys, memStatus.ullAvailPhys, nil +} + +// getSystemMemoryUsedFraction returns the fraction of physical memory currently in use on Windows. +// It uses the GlobalMemoryStatusEx Windows API to query system memory statistics. +var getSystemMemoryUsedFraction = func() (float64, error) { + totalPhys, availPhys, err := globalMemoryStatusEx() + if err != nil { + return 0, err + } + + if totalPhys == 0 { + return 0, fmt.Errorf("total physical memory is 0") + } + + // Calculate used memory from total and available + usedPhys := totalPhys - availPhys + fraction := float64(usedPhys) / float64(totalPhys) + + // Sanity check: fraction should be between 0 and 1 + if fraction < 0 || fraction > 1 { + return 0, fmt.Errorf("calculated memory fraction out of range: %v (used: %d, total: %d)", + fraction, usedPhys, totalPhys) + } + + return fraction, nil +} diff --git a/pkg/spooledtempfile/memory_windows_test.go b/pkg/spooledtempfile/memory_windows_test.go new file mode 100644 index 0000000..6e6b31e --- /dev/null +++ b/pkg/spooledtempfile/memory_windows_test.go @@ -0,0 +1,66 @@ +//go:build windows + +package spooledtempfile + +import ( + "testing" +) + +// TestGetSystemMemoryUsedFraction verifies that the Windows implementation returns a valid memory fraction between 0 and 1. +func TestGetSystemMemoryUsedFraction(t *testing.T) { + fraction, err := getSystemMemoryUsedFraction() + if err != nil { + t.Fatalf("getSystemMemoryUsedFraction() failed: %v", err) + } + + if fraction < 0 || fraction > 1 { + t.Fatalf("memory fraction out of range: got %v, want 0.0-1.0", fraction) + } + + t.Logf("Current system memory usage: %.2f%%", fraction*100) +} + +// TestGlobalMemoryStatusEx verifies that we can successfully retrieve memory values via globalMemoryStatusEx. +func TestGlobalMemoryStatusEx(t *testing.T) { + totalPhys, availPhys, err := globalMemoryStatusEx() + if err != nil { + t.Fatalf("globalMemoryStatusEx failed: %v", err) + } + + if totalPhys == 0 { + t.Fatal("total physical memory is 0") + } + + t.Logf("Total physical memory: %d bytes (%.2f GB)", totalPhys, float64(totalPhys)/(1024*1024*1024)) + t.Logf("Available physical memory: %d bytes (%.2f GB)", availPhys, float64(availPhys)/(1024*1024*1024)) + + usedPhys := totalPhys - availPhys + t.Logf("Used physical memory: %d bytes (%.2f GB)", usedPhys, float64(usedPhys)/(1024*1024*1024)) + + usedPercent := float64(usedPhys) / float64(totalPhys) * 100 + t.Logf("Memory usage: %.2f%%", usedPercent) +} + +// TestMemoryFractionConsistency verifies that multiple calls return consistent values. 
+func TestMemoryFractionConsistency(t *testing.T) { + const calls = 5 + var fractions [calls]float64 + + for i := 0; i < calls; i++ { + frac, err := getSystemMemoryUsedFraction() + if err != nil { + t.Fatalf("call %d failed: %v", i, err) + } + fractions[i] = frac + } + + // Check that all values are within a reasonable range of each other + // Memory usage shouldn't vary wildly between consecutive calls + for i := 1; i < calls; i++ { + diff := fractions[i] - fractions[i-1] + if diff < -0.2 || diff > 0.2 { + t.Errorf("memory fraction changed too much between calls: %v -> %v (diff: %v)", + fractions[i-1], fractions[i], diff) + } + } +} From 5fe2448f5f053dc456fff2e45d6c76c5d80573c8 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 18 Oct 2025 17:57:52 +0000 Subject: [PATCH 09/10] Initial plan From 4474dcb2811505e37e0ff200aa16fdbcdd5821a9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 15 Nov 2025 06:11:31 +0000 Subject: [PATCH 10/10] Use testFileSingleHashCheck in POST tests and rebase on master Co-authored-by: NGTmeaty <2244519+NGTmeaty@users.noreply.github.com> --- client_test.go | 15 +++------------ smoke_test.go | 10 +++++----- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/client_test.go b/client_test.go index 38bda12..5c4ef22 100644 --- a/client_test.go +++ b/client_test.go @@ -1915,7 +1915,7 @@ func TestHTTPClientPOSTWithTextPayload(t *testing.T) { // Check the WARC records contain the POST request and response for _, path := range files { - testFileHash(t, path) + testFileSingleHashCheck(t, path, "sha1:RFV2ZU2BHITF3PW7BSPBQE65GFZS7F5G", []string{"154"}, 1, server.URL+"/") file, err := os.Open(path) if err != nil { @@ -1929,7 +1929,6 @@ func TestHTTPClientPOSTWithTextPayload(t *testing.T) { } foundRequest := false - foundResponse := false for { record, err := reader.ReadRecord() @@ -1958,20 +1957,12 @@ func TestHTTPClientPOSTWithTextPayload(t *testing.T) { } } - // Check for response record - if record.Header.Get("WARC-Type") == "response" { - foundResponse = true - } - record.Content.Close() } if !foundRequest { t.Error("No request record found in WARC file") } - if !foundResponse { - t.Error("No response record found in WARC file") - } } } @@ -2047,7 +2038,7 @@ func TestHTTPClientPOSTWithJSONPayload(t *testing.T) { // Check the WARC records contain the POST request with JSON body for _, path := range files { - testFileHash(t, path) + testFileSingleHashCheck(t, path, "sha1:IAKLOIOTQX2W7PAAWWA2TELLU5HCKO3V", []string{"191"}, 1, server.URL+"/") file, err := os.Open(path) if err != nil { @@ -2168,7 +2159,7 @@ func TestHTTPClientPOSTWithFormData(t *testing.T) { // Check the WARC records contain the POST request with form data for _, path := range files { - testFileHash(t, path) + testFileSingleHashCheck(t, path, "sha1:DGXE2J6TLUT3GYLTA2LNA4NQMMPF5SWX", []string{"175"}, 1, server.URL+"/") file, err := os.Open(path) if err != nil { diff --git a/smoke_test.go b/smoke_test.go index 6411001..bb92844 100644 --- a/smoke_test.go +++ b/smoke_test.go @@ -30,11 +30,11 @@ func TestSmokeWARCFormatRegression(t *testing.T) { // These values were extracted from a known-good WARC file and serve as // a snapshot of correct format behavior. 
expectedRecords := []struct { - warcType string - contentLength int64 - blockDigest string - payloadDigest string // only for response records - targetURI string // only for response records + warcType string + contentLength int64 + blockDigest string + payloadDigest string // only for response records + targetURI string // only for response records }{ { warcType: "warcinfo",