diff --git a/src/utils.js b/src/utils.js index 7148209..8d6fdba 100644 --- a/src/utils.js +++ b/src/utils.js @@ -53,8 +53,57 @@ export function equals(a, b) { : typeof a === 'object' && Object.keys(a).length === Object.keys(b).length && Object.keys(a).every(k => equals(a[k], b[k])) } +/** + * Get the byte length using fetch with a ranged GET request. + * Aborts the request if server returns 200 instead of 206. + * + * @param {string} url + * @param {RequestInit} [requestInit] fetch options + * @param {typeof globalThis.fetch} [fetchFn] fetch function to use + * @returns {Promise} + */ +async function byteLengthFromUrlUsingFetch(url, requestInit = {}, fetchFn = globalThis.fetch) { + const controller = new AbortController() + const headers = new Headers(requestInit.headers) + headers.set('Range', 'bytes=0-0') + + const res = await fetchFn(url, { + ...requestInit, + headers, + signal: controller.signal, + }) + + if (!res.ok) throw new Error(`fetch with range failed ${res.status}`) + + // Server supports Range requests (206 Partial Content) + if (res.status === 206) { + const contentRange = res.headers.get('Content-Range') + if (!contentRange) throw new Error('missing content-range header') + + // Parse "bytes 0-0/9446073" to get total length + const match = contentRange.match(/bytes \d+-\d+\/(\d+)/) + if (!match) throw new Error(`invalid content-range header: ${contentRange}`) + + return parseInt(match[1]) + } + + // Server ignored Range and returned 200 - get Content-Length and abort request + if (res.status === 200) { + const contentLength = res.headers.get('Content-Length') + + // Abort the request to stop any ongoing download + controller.abort() + + if (contentLength) return parseInt(contentLength) + } + + throw new Error('server does not support range requests and missing content-length') +} + /** * Get the byte length of a URL using a HEAD request. + * If HEAD fails with 403 (e.g., with signed S3 URLs), falls back to a ranged GET request. + * If HEAD succeeds but Content-Length is missing, falls back to GET with range. * If requestInit is provided, it will be passed to fetch. * * @param {string} url @@ -64,13 +113,20 @@ export function equals(a, b) { */ export async function byteLengthFromUrl(url, requestInit, customFetch) { const fetch = customFetch ?? globalThis.fetch - return await fetch(url, { ...requestInit, method: 'HEAD' }) - .then(res => { - if (!res.ok) throw new Error(`fetch head failed ${res.status}`) - const length = res.headers.get('Content-Length') - if (!length) throw new Error('missing content length') - return parseInt(length) - }) + const res = await fetch(url, { ...requestInit, method: 'HEAD' }) + + // If HEAD request is forbidden (common with signed S3 URLs), try GET with range + if (res.status === 403) { + return byteLengthFromUrlUsingFetch(url, requestInit, fetch) + } + + if (!res.ok) throw new Error(`fetch head failed ${res.status}`) + const length = res.headers.get('Content-Length') + // If Content-Length is missing from HEAD, fallback to GET with range + if (!length) { + return byteLengthFromUrlUsingFetch(url, requestInit, fetch) + } + return parseInt(length) } /** diff --git a/test/utils.test.js b/test/utils.test.js index 3a52c83..4e813d6 100644 --- a/test/utils.test.js +++ b/test/utils.test.js @@ -56,13 +56,21 @@ describe('byteLengthFromUrl', () => { await expect(byteLengthFromUrl('https://example.com')).rejects.toThrow('fetch head failed 404') }) - it('throws an error if Content-Length header is missing', async () => { - global.fetch = vi.fn().mockResolvedValueOnce({ - ok: true, - headers: new Map(), - }) + it('falls back to GET with range if Content-Length header is missing from HEAD', async () => { + const customFetch = vi.fn() + .mockResolvedValueOnce({ + ok: true, + headers: new Map(), + }) + .mockResolvedValueOnce({ + ok: true, + status: 206, + headers: new Map([['Content-Range', 'bytes 0-0/2048']]), + }) - await expect(byteLengthFromUrl('https://example.com')).rejects.toThrow('missing content length') + const result = await byteLengthFromUrl('https://example.com', undefined, customFetch) + expect(result).toBe(2048) + expect(customFetch).toHaveBeenCalledTimes(2) }) @@ -95,6 +103,137 @@ describe('byteLengthFromUrl', () => { expect(result).toBe(2048) expect(customFetch).toHaveBeenCalledWith('https://example.com', { ...requestInit, method: 'HEAD' }) }) + + it('falls back to ranged GET when HEAD returns 403', async () => { + const customFetch = vi.fn() + .mockResolvedValueOnce({ ok: false, status: 403 }) + .mockResolvedValueOnce({ + ok: true, + status: 206, + headers: new Map([['Content-Range', 'bytes 0-0/9446073']]), + }) + + const result = await byteLengthFromUrl('https://example.com', undefined, customFetch) + expect(result).toBe(9446073) + expect(customFetch).toHaveBeenCalledTimes(2) + expect(customFetch).toHaveBeenNthCalledWith(1, 'https://example.com', { method: 'HEAD' }) + }) + + it('fallback throws error if Content-Range header is missing on 206 response', async () => { + const customFetch = vi.fn() + .mockResolvedValueOnce({ ok: false, status: 403 }) + .mockResolvedValueOnce({ + ok: true, + status: 206, + headers: new Map(), + }) + + await expect(byteLengthFromUrl('https://example.com', undefined, customFetch)).rejects.toThrow('missing content-range header') + }) + + it('fallback throws error if Content-Range header is invalid', async () => { + const customFetch = vi.fn() + .mockResolvedValueOnce({ ok: false, status: 403 }) + .mockResolvedValueOnce({ + ok: true, + status: 206, + headers: new Map([['Content-Range', 'invalid format']]), + }) + + await expect(byteLengthFromUrl('https://example.com', undefined, customFetch)).rejects.toThrow('invalid content-range header') + }) + + it('fallback uses Content-Length when server returns 200 (Range not supported)', async () => { + const mockArrayBuffer = new ArrayBuffer(5242880) + + const customFetch = vi.fn() + .mockResolvedValueOnce({ ok: false, status: 403 }) + .mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Map([['Content-Length', '5242880']]), + arrayBuffer: () => Promise.resolve(mockArrayBuffer), + }) + + const result = await byteLengthFromUrl('https://example.com', undefined, customFetch) + expect(result).toBe(5242880) + }) + + it('fallback throws error when server returns 200 without Content-Length', async () => { + const customFetch = vi.fn() + .mockResolvedValueOnce({ ok: false, status: 403 }) + .mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Map(), + body: null, + }) + + await expect(byteLengthFromUrl('https://example.com', undefined, customFetch)).rejects.toThrow( + 'server does not support range requests and missing content-length' + ) + }) + + describe('fetch with AbortController', () => { + it('aborts request when server returns 200 with Content-Length', async () => { + let capturedSignal = null + + const customFetch = vi.fn() + .mockResolvedValueOnce({ ok: false, status: 403 }) // HEAD fails + .mockImplementation((url, options) => { // GET returns 200 + capturedSignal = options.signal + return Promise.resolve({ + ok: true, + status: 200, + headers: new Map([['Content-Length', '5242880']]), + }) + }) + + const result = await byteLengthFromUrl('https://example.com', undefined, customFetch) + expect(result).toBe(5242880) + expect(capturedSignal).toBeDefined() + // @ts-ignore - capturedSignal is assigned in the mock + expect(capturedSignal.aborted).toBe(true) + }) + + it('does not abort when server returns 206', async () => { + let capturedSignal = null + + const customFetch = vi.fn() + .mockResolvedValueOnce({ ok: false, status: 403 }) // HEAD fails + .mockImplementation((url, options) => { // GET returns 206 + capturedSignal = options.signal + return Promise.resolve({ + ok: true, + status: 206, + headers: new Map([['Content-Range', 'bytes 0-0/9446073']]), + }) + }) + + const result = await byteLengthFromUrl('https://example.com', undefined, customFetch) + expect(result).toBe(9446073) + expect(capturedSignal).toBeDefined() + // @ts-ignore - capturedSignal is assigned in the mock + expect(capturedSignal.aborted).toBe(false) + }) + + it('passes abort signal to fetch', async () => { + const customFetch = vi.fn() + .mockResolvedValueOnce({ ok: false, status: 403 }) // HEAD fails + .mockResolvedValueOnce({ // GET returns 206 + ok: true, + status: 206, + headers: new Map([['Content-Range', 'bytes 0-0/1024']]), + }) + + await byteLengthFromUrl('https://example.com', undefined, customFetch) + + // Check second call (the GET with range) + const secondCallArgs = customFetch.mock.calls[1] + expect(secondCallArgs[1]).toHaveProperty('signal') + expect(secondCallArgs[1].signal).toBeInstanceOf(AbortSignal) + }) + }) }) describe('asyncBufferFromUrl', () => {