// Decoder for structs serialized with the Thrift compact protocol (TCompactProtocol).

// TCompactProtocol type ids, as stored in the low nibble of field and collection headers
const CompactType = {
  STOP: 0,
  TRUE: 1,
  FALSE: 2,
  BYTE: 3,
  I16: 4,
  I32: 5,
  I64: 6,
  DOUBLE: 7,
  BINARY: 8,
  LIST: 9,
  SET: 10,
  MAP: 11,
  STRUCT: 12,
  UUID: 13,
} as const

/**
 * Parse a TCompactProtocol struct located at the start of a buffer.
 *
 * Fields are read until a STOP field is found or the buffer is exhausted.
 * Every decoded value is stored under the key `field_<id>`.
 *
 * @param buffer bytes of the serialized struct, starting at offset 0
 * @returns [number of bytes consumed, decoded fields]
 */
export function deserializeTCompactProtocol(buffer: ArrayBuffer): [number, Record<string, any>] {
  const view = new DataView(buffer)
  const result: Record<string, any> = {}
  let index = 0
  let lastFid = 0

  while (index < buffer.byteLength) {
    const [type, fid, valueIndex, newLastFid] = readFieldBegin(view, index, lastFid)
    index = valueIndex
    lastFid = newLastFid
    if (type === CompactType.STOP) {
      break
    }

    const [value, nextIndex] = readElement(view, type, index)
    result[`field_${fid}`] = value
    index = nextIndex
  }

  return [index, result]
}

/**
 * Read a single value of the given compact type.
 *
 * MAP and SET are not supported yet and end up in the "Unhandled type" error.
 *
 * @param view view over the serialized bytes
 * @param type compact type id of the value
 * @param index offset of the value, relative to the start of the view
 * @returns [value, newIndex]
 * @throws Error when the type id is not handled
 * @throws RangeError when a read runs past the end of the view
 */
function readElement(view: DataView, type: number, index: number): [any, number] {
  switch (type) {
    // As a struct field, a boolean lives entirely in the field header: no payload bytes
    case CompactType.TRUE:
      return [true, index]
    case CompactType.FALSE:
      return [false, index]
    case CompactType.BYTE:
      return [view.getInt8(index), index + 1]
    case CompactType.I16:
    case CompactType.I32:
      return readZigZag(view, index)
    case CompactType.I64:
      return readZigZagBigInt(view, index)
    case CompactType.DOUBLE:
      // read as little endian
      return [view.getFloat64(index, true), index + 8]
    case CompactType.BINARY: {
      // varint byte length followed by utf-8 bytes, no \0 delimiter
      const [length, start] = readVarInt(view, index)
      const end = start + length
      if (end > view.byteLength) {
        throw new RangeError(`Binary of length ${length} at offset ${start} exceeds the buffer`)
      }
      // offsets are relative to the view, which may itself start inside its buffer
      const bytes = new Uint8Array(view.buffer, view.byteOffset + start, length)
      return [new TextDecoder().decode(bytes), end]
    }
    case CompactType.LIST: {
      const [elemType, listSize, listIndex] = readCollectionBegin(view, index)
      // Unlike boolean fields, boolean list elements take one byte each (1 means true)
      const isBool = elemType === CompactType.TRUE || elemType === CompactType.FALSE
      const listValues: any[] = []
      let offset = listIndex
      for (let i = 0; i < listSize; i++) {
        if (isBool) {
          listValues.push(view.getInt8(offset++) === CompactType.TRUE)
        } else {
          const [elem, next] = readElement(view, elemType, offset)
          listValues.push(elem)
          offset = next
        }
      }
      return [listValues, offset]
    }
    case CompactType.STRUCT: {
      const structValues: Record<string, any> = {}
      let offset = index
      let lastFid = 0 // field id deltas restart inside a nested struct
      while (true) {
        const [fieldType, fid, valueIndex, newLastFid] = readFieldBegin(view, offset, lastFid)
        offset = valueIndex
        lastFid = newLastFid
        if (fieldType === CompactType.STOP) {
          break
        }
        const [value, next] = readElement(view, fieldType, offset)
        structValues[`field_${fid}`] = value
        offset = next
      }
      return [structValues, offset]
    }
    // TODO: MAP and SET
    case CompactType.UUID: {
      // 16 raw bytes rendered as 32 lowercase hex digits
      let uuid = ''
      let offset = index
      for (let i = 0; i < 16; i++) {
        uuid += view.getUint8(offset++).toString(16).padStart(2, '0')
      }
      return [uuid, offset]
    }
    default:
      throw new Error(`Unhandled type: ${type}`)
  }
}

/**
 * Var int, also known as Unsigned LEB128.
 * Var ints take 1 to 5 bytes (int32) or 1 to 10 bytes (int64).
 * The encoder takes a Big Endian unsigned integer, left-pads the bit-string to
 * make it a multiple of 7 bits, splits it into 7-bit groups, prefixes the
 * most-significant group with a 0 bit and the remaining groups with a 1 bit,
 * and writes the result as Little Endian.
 *
 * @param view view over the serialized bytes
 * @param index offset of the first varint byte
 * @returns [unsigned 32-bit value, newIndex]
 * @throws Error when the varint does not terminate within 5 bytes
 */
function readVarInt(view: DataView, index: number): [number, number] {
  let result = 0
  // 5 groups of 7 bits cover 32 bits; a larger shift count would wrap around (mod 32)
  for (let shift = 0; shift < 35; shift += 7) {
    const byte = view.getUint8(index++)
    result |= (byte & 0x7f) << shift
    if ((byte & 0x80) === 0) {
      // bitwise operators yield a signed int32; reinterpret it as unsigned
      return [result >>> 0, index]
    }
  }
  throw new Error('varint is too long for int32')
}

/**
 * Read a varint as a bigint.
 *
 * @param view view over the serialized bytes
 * @param index offset of the first varint byte
 * @returns [unsigned value, newIndex]
 * @throws Error when the varint does not terminate within 10 bytes
 */
function readVarBigInt(view: DataView, index: number): [bigint, number] {
  let result = BigInt(0)
  // 10 groups of 7 bits cover 64 bits
  for (let shift = 0; shift < 70; shift += 7) {
    const byte = view.getUint8(index++)
    result |= BigInt(byte & 0x7f) << BigInt(shift)
    if ((byte & 0x80) === 0) {
      return [result, index]
    }
  }
  throw new Error('varint is too long for int64')
}

/**
 * Values of type int32 and int64 are transformed to a zigzag int.
 * A zigzag int folds positive and negative numbers into the positive number space.
 *
 * @param view view over the serialized bytes
 * @param index offset of the first varint byte
 * @returns [signed 32-bit value, newIndex]
 */
function readZigZag(view: DataView, index: number): [number, number] {
  const [zigzag, newIndex] = readVarInt(view, index)
  // the lowest bit carries the sign, the remaining bits the magnitude
  const value = (zigzag >>> 1) ^ -(zigzag & 1)
  return [value, newIndex]
}

/**
 * A zigzag int folds positive and negative numbers into the positive number space.
 * This version returns a BigInt.
 *
 * @param view view over the serialized bytes
 * @param index offset of the first varint byte
 * @returns [signed value, newIndex]
 */
function readZigZagBigInt(view: DataView, index: number): [bigint, number] {
  const [zigzag, newIndex] = readVarBigInt(view, index)
  // the lowest bit carries the sign, the remaining bits the magnitude
  const value = (zigzag >> BigInt(1)) ^ -(zigzag & BigInt(1))
  return [value, newIndex]
}

/**
 * Get thrift type from half a byte
 *
 * @param byte header byte of a field or collection
 * @returns compact type id stored in the low nibble
 */
function getCompactType(byte: number): number {
  return byte & 0x0f
}

/**
 * Read field type and field id
 *
 * @param view view over the serialized bytes
 * @param index offset of the field header
 * @param lastFid id of the previous field in the same struct (0 for the first field)
 * @returns [type, field id, newIndex, newLastFid]
 */
function readFieldBegin(view: DataView, index: number, lastFid: number): [number, number, number, number] {
  const header = view.getUint8(index++)
  const type = getCompactType(header)
  if (type === CompactType.STOP) {
    // STOP also ends a struct
    return [CompactType.STOP, 0, index, lastFid]
  }
  const delta = header >> 4
  if (delta === 0) {
    // not a delta, the field id follows as a zigzag varint
    const [fid, newIndex] = readZigZag(view, index)
    return [type, fid, newIndex, fid]
  }
  // high nibble is the distance from the previous field id
  const fid = lastFid + delta
  return [type, fid, index, fid]
}

/**
 * Read the header of a list: element type and element count.
 *
 * @param view view over the serialized bytes
 * @param index offset of the collection header
 * @returns [element type, size, newIndex]
 */
function readCollectionBegin(view: DataView, index: number): [number, number, number] {
  const sizeType = view.getUint8(index++)
  const size = sizeType >> 4
  const type = getCompactType(sizeType)
  if (size === 15) {
    // a size nibble of 15 means the real size follows as a varint
    const [newSize, newIndex] = readVarInt(view, index)
    return [type, newSize, newIndex]
  }
  return [type, size, index]
}

/**
 * Convert int to varint. Outputs 1-5 bytes for int32.
 *
 * @param n integer to encode
 * @returns varint bytes, least-significant group first
 */
export function toVarInt(n: number): number[] {
  const varInt: number[] = []
  let rest = n
  // emit 7 bits at a time with the continuation bit set while higher bits remain
  while ((rest & ~0x7f) !== 0) {
    varInt.push((rest & 0x7f) | 0x80)
    rest >>>= 7
  }
  varInt.push(rest)
  return varInt
}