forked from sheetjs/docs.sheetjs.com
		
	
		
			
	
	
		
			93 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			TypeScript
		
	
	
	
	
	
		
		
			
		
	
	
			93 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			TypeScript
		
	
	
	
	
	
|  | import { useState, useCallback } from "kaioken"; | ||
|  | import { TensorContainerObject, data, layers, linspace, train, sequential } from "@tensorflow/tfjs"; | ||
|  | import { read, utils } from "xlsx"; | ||
|  | 
 | ||
|  | import type { Tensor, Rank } from "@tensorflow/tfjs"; | ||
|  | import type { WorkSheet } from "xlsx"; | ||
|  | 
 | ||
/**
 * Element shape the CSV dataset is cast to before pre-processing: one entry
 * for the feature side (`xs`) and one for the label side (`ys`).
 *
 * NOTE(review): both members are declared as `Tensor`, yet the component reads
 * them with `Object.values(...)`, i.e. it treats them as keyed records.
 * Confirm the declared type against what `data.csv` actually yields.
 */
interface Data extends TensorContainerObject {
  xs: Tensor;
  ys: Tensor;
}

/** A tf.js dataset whose elements have the `Data` shape. */
type DSet = data.Dataset<Data>;
|  | 
 | ||
/**
 * Serialize a worksheet to CSV and expose the text through a blob URL.
 *
 * The caller owns the returned URL and must release it with
 * `URL.revokeObjectURL` once nothing reads from it any more.
 *
 * @param worksheet worksheet to serialize
 * @returns a `blob:` URL whose body is the CSV text (type `text/csv`)
 */
function worksheetToCsvUrl(worksheet: WorkSheet): string {
  const csv = utils.sheet_to_csv(worksheet);
  /* a string blob part is encoded as UTF-8, same as TextEncoder would do */
  const blob = new Blob([csv], { type: "text/csv" });
  return URL.createObjectURL(blob);
}

/**
 * Demo component: on click it downloads a workbook, converts the first
 * worksheet to CSV, feeds the CSV to `data.csv`, trains a single-unit dense
 * model on the min-max normalized "Horsepower" / "Miles_per_Gallon" columns
 * and renders the de-normalized predictions for 9 evenly spaced inputs.
 *
 * While the job runs the button is disabled and the latest epoch / loss is
 * shown; any failure is shown as `ERROR: ...`.
 */
export default function SheetJSToTFJSCSV() {
  const [output, setOutput] = useState("");
  const [results, setResults] = useState<[number, number][]>([]);
  const [disabled, setDisabled] = useState(false);

  const doit = useCallback(async () => {
    setResults([]); setOutput(""); setDisabled(true);

    /* release callbacks for everything acquired below; run in `finally` */
    const cleanup: Array<() => void> = [];
    try {
      /* fetch file; bail out on HTTP errors instead of parsing an error page */
      const source = "https://docs.sheetjs.com/cd.xls";
      const f = await fetch(source);
      if (!f.ok) throw new Error(`HTTP ${f.status} while fetching ${source}`);
      const ab = await f.arrayBuffer();

      /* parse file and get first worksheet */
      const wb = read(ab);
      const first = wb.SheetNames[0];
      const ws = first === undefined ? undefined : wb.Sheets[first];
      if (!ws) throw new Error("workbook contains no worksheets");

      /* generate blob URL; it has to stay valid until training has finished
         because the dataset is iterated again for the scan and for training */
      const url = worksheetToCsvUrl(ws);
      cleanup.push(() => URL.revokeObjectURL(url));

      /* feed to tf.js */
      const dataset = data.csv(url, {
        hasHeader: true,
        configuredColumnsOnly: true,
        columnConfigs: {
          "Horsepower": { required: false, default: 0 },
          "Miles_per_Gallon": { required: false, default: 0, isLabel: true }
        }
      });

      /* pre-process data: flatten each side to an array and keep only rows
         whose values are all positive (drops rows that fell back to 0) */
      let flat = (dataset as unknown as DSet)
        .map(({ xs, ys }) => ({ xs: Object.values(xs), ys: Object.values(ys) }))
        .filter(({ xs, ys }) => [...xs, ...ys].every(v => v > 0));

      /* first pass over the data: find the range of each column */
      let minX = Infinity, maxX = -Infinity, minY = Infinity, maxY = -Infinity;
      await flat.forEachAsync(({ xs, ys }) => {
        minX = Math.min(minX, xs[0]); maxX = Math.max(maxX, xs[0]);
        minY = Math.min(minY, ys[0]); maxY = Math.max(maxY, ys[0]);
      });

      /* an empty dataset (range is -Infinity) or a constant column (range is 0)
         would turn every normalized value into NaN / Infinity */
      const spanX = maxX - minX, spanY = maxY - minY;
      if (!(spanX > 0) || !(spanY > 0)) {
        throw new Error("dataset needs at least two distinct positive values per column");
      }

      /* min-max normalize both columns to [0, 1] */
      flat = flat.map(({ xs, ys }) => ({
        xs: xs.map(v => (v - minX) / spanX),
        ys: ys.map(v => (v - minY) / spanY)
      }));
      const batch = flat.batch(32);

      /* build and train model */
      const model = sequential();
      cleanup.push(() => { model.dispose(); });
      model.add(layers.dense({ inputShape: [1], units: 1 }));
      model.compile({ optimizer: train.sgd(0.000001), loss: 'meanSquaredError' });
      await model.fitDataset(batch, { epochs: 100, callbacks: { onEpochEnd: async (epoch, logs) => {
        setOutput(`${epoch}:${logs?.loss}`);
      }}});

      /* predict values for 9 evenly spaced normalized inputs */
      const inp = linspace(0, 1, 9);
      cleanup.push(() => inp.dispose());
      const pred = model.predict(inp) as Tensor<Rank>;
      cleanup.push(() => pred.dispose());

      /* read the values back asynchronously and undo the normalization */
      const [xs, ys] = await Promise.all([inp.data(), pred.data()]);
      setResults(Array.from(xs, (x, i): [number, number] => [
        x * spanX + minX,
        ys[i] * spanY + minY
      ]));
      setOutput("");
    } catch (e) {
      setOutput(`ERROR: ${String(e)}`);
    } finally {
      setDisabled(false);
      /* release in reverse order of acquisition: tensors, model, blob URL */
      for (const release of cleanup.reverse()) release();
    }
  }, []);

  return ( <>
    <button onclick={doit} disabled={disabled}>Click to run</button><br/>
    {output ? <pre>{output}</pre> : <></>}
    {results.length > 0 ? <table><thead><tr><th>Horsepower</th><th>MPG</th></tr></thead><tbody>
    {results.map((r, i) => <tr key={i}><td>{r[0]}</td><td>{r[1].toFixed(2)}</td></tr>)}
    </tbody></table> : <></>}
  </> );
}