diff --git a/src/utils.js b/src/utils.js index 9a11350..f57a59d 100644 --- a/src/utils.js +++ b/src/utils.js @@ -59,9 +59,11 @@ export function equals(a, b) { * * @param {string} url * @param {RequestInit} [requestInit] fetch options + * @param {typeof globalThis.fetch} [customFetch] fetch function to use * @returns {Promise} */ -export async function byteLengthFromUrl(url, requestInit) { +export async function byteLengthFromUrl(url, requestInit, customFetch) { + const fetch = customFetch ?? globalThis.fetch return await fetch(url, { ...requestInit, method: 'HEAD' }) .then(res => { if (!res.ok) throw new Error(`fetch head failed ${res.status}`) @@ -74,18 +76,21 @@ export async function byteLengthFromUrl(url, requestInit) { /** * Construct an AsyncBuffer for a URL. * If byteLength is not provided, will make a HEAD request to get the file size. + * If fetch is provided, it will be used instead of the global fetch. * If requestInit is provided, it will be passed to fetch. * * @param {object} options * @param {string} options.url * @param {number} [options.byteLength] + * @param {typeof globalThis.fetch} [options.fetch] fetch function to use * @param {RequestInit} [options.requestInit] * @returns {Promise} */ -export async function asyncBufferFromUrl({ url, byteLength, requestInit }) { +export async function asyncBufferFromUrl({ url, byteLength, requestInit, fetch: customFetch }) { if (!url) throw new Error('missing url') + const fetch = customFetch ?? globalThis.fetch // byte length from HEAD request - byteLength ||= await byteLengthFromUrl(url, requestInit) + byteLength ||= await byteLengthFromUrl(url, requestInit, fetch) /** * A promise for the whole buffer, if range requests are not supported. diff --git a/test/utils.test.js b/test/utils.test.js index 13577ef..3a52c83 100644 --- a/test/utils.test.js +++ b/test/utils.test.js @@ -83,6 +83,18 @@ describe('byteLengthFromUrl', () => { await expect(byteLengthFromUrl('https://example.com')).rejects.toThrow('fetch head failed 401') }) + + it ('uses the provided fetch function, along with requestInit if passed', async () => { + const customFetch = vi.fn().mockResolvedValueOnce({ + ok: true, + headers: new Map([['Content-Length', '2048']]), + }) + + const requestInit = { headers: { authorization: 'Bearer token' } } + const result = await byteLengthFromUrl('https://example.com', requestInit, customFetch) + expect(result).toBe(2048) + expect(customFetch).toHaveBeenCalledWith('https://example.com', { ...requestInit, method: 'HEAD' }) + }) }) describe('asyncBufferFromUrl', () => { @@ -226,4 +238,43 @@ describe('asyncBufferFromUrl', () => { expect(fetch).toBeCalledTimes(1) }) }) + + describe('when a custom fetch function is provided', () => { + it ('is used to get the byte length', async () => { + const customFetch = vi.fn().mockResolvedValueOnce({ + ok: true, + headers: new Map([['Content-Length', '2048']]), + }) + + const requestInit = { headers: { authorization: 'Bearer token' } } + const buffer = await asyncBufferFromUrl({ url: 'https://example.com', requestInit, fetch: customFetch }) + expect(buffer.byteLength).toBe(2048) + expect(customFetch).toHaveBeenCalledWith('https://example.com', { ...requestInit, method: 'HEAD' }) + }) + it ('is used to fetch the slice', async () => { + const mockArrayBuffer = new ArrayBuffer(35) + let counter = 0 + function rateLimitedFetch() { + counter++ + if (counter === 2) { + return Promise.resolve({ ok: true, status: 206, body: {}, arrayBuffer: () => Promise.resolve(mockArrayBuffer) }) + } + return Promise.resolve({ ok: false, status: 429 }) + } + const customFetch = vi.fn().mockImplementation(async () => { + while (true) { + const result = await rateLimitedFetch() + if (result.ok) { + return result + } + await new Promise(resolve => setTimeout(resolve, 100)) // wait for 100ms before retrying + } + }) + const requestInit = { headers: { authorization: 'Bearer token' } } + const buffer = await asyncBufferFromUrl({ url: 'https://example.com', byteLength: 1024, requestInit, fetch: customFetch }) + const result = await buffer.slice(50, 85) + expect(result).toBe(mockArrayBuffer) + expect(customFetch).toHaveBeenCalledWith('https://example.com', { headers: new Headers({ ...requestInit.headers, Range: 'bytes=50-84' }) }) + }) + }) })