forked from sheetjs/docs.sheetjs.com
		
	
		
			
	
	
		
			93 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			TypeScript
		
	
	
	
	
	
		
		
			
		
	
	
			93 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			TypeScript
		
	
	
	
	
	
| 
								 | 
							
								import { useState, useCallback } from "kaioken";
							 | 
						||
| 
								 | 
							
								import { TensorContainerObject, data, layers, linspace, train, sequential } from "@tensorflow/tfjs";
							 | 
						||
| 
								 | 
							
								import { read, utils } from "xlsx";
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import type { Tensor, Rank } from "@tensorflow/tfjs";
							 | 
						||
| 
								 | 
							
								import type { WorkSheet } from "xlsx";
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								/**
								 * Element shape that the CSV dataset is cast to before pre-processing.
								 *
								 * NOTE(review): both fields are declared as `Tensor`, but the consumer in
								 * this file calls `Object.values` on them, so the runtime values look like
								 * column-keyed objects rather than tensors -- confirm before tightening.
								 */
								interface Data extends TensorContainerObject {
								  /** Feature side of one dataset element. */
								  xs: Tensor;
								  /** Label side of one dataset element. */
								  ys: Tensor;
								}

								/** Dataset whose elements follow the {@link Data} shape. */
								type DSet = data.Dataset<Data>;
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								/**
								 * Demo component: downloads a spreadsheet, converts its first worksheet to
								 * CSV behind a blob URL, feeds that URL to the TensorFlow.js CSV dataset
								 * loader, trains a one-unit dense model on Horsepower -> Miles_per_Gallon,
								 * and renders predictions for nine evenly spaced inputs.
								 *
								 * State:
								 * - `output`: progress line (`epoch:loss`) or an `ERROR: ...` message.
								 * - `results`: predicted `[horsepower, mpg]` pairs, de-normalized.
								 * - `disabled`: true while a run is in progress so the button cannot be
								 *   clicked again.
								 */
								export default function SheetJSToTFJSCSV() {
								  const [output, setOutput] = useState("");
								  const [results, setResults] = useState<[number, number][]>([]);
								  const [disabled, setDisabled] = useState(false);

								  /**
								   * Serialize a worksheet to CSV and expose it through a blob URL.
								   * The caller owns the returned URL and must revoke it when done.
								   */
								  function worksheet_to_csv_url(worksheet: WorkSheet) {
								    /* generate CSV */
								    const csv = utils.sheet_to_csv(worksheet);

								    /* CSV -> Uint8Array -> Blob */
								    const u8 = new TextEncoder().encode(csv);
								    const blob = new Blob([u8], { type: "text/csv" });

								    /* generate a blob URL */
								    return URL.createObjectURL(blob);
								  }

								  /** Run the full fetch -> parse -> train -> predict pipeline once. */
								  const doit = useCallback(async () => {
								    setResults([]); setOutput(""); setDisabled(true);

								    /* tracked outside `try` so `finally` can release it on every path */
								    let url: string | undefined;
								    try {
								      /* fetch file; reject HTTP errors instead of parsing an error page */
								      const f = await fetch("https://docs.sheetjs.com/cd.xls");
								      if (!f.ok) throw new Error(`fetch failed: ${f.status} ${f.statusText}`);
								      const ab = await f.arrayBuffer();

								      /* parse file and get first worksheet */
								      const wb = read(ab);
								      const firstName = wb.SheetNames[0];
								      const ws = firstName === undefined ? undefined : wb.Sheets[firstName];
								      if (!ws) throw new Error("workbook has no worksheets");

								      /* generate blob URL */
								      url = worksheet_to_csv_url(ws);

								      /* feed to tf.js */
								      const dataset = data.csv(url, {
								        hasHeader: true,
								        configuredColumnsOnly: true,
								        columnConfigs: {
								          "Horsepower": { required: false, default: 0 },
								          "Miles_per_Gallon": { required: false, default: 0, isLabel: true }
								        }
								      });

								      /* pre-process data: flatten each side to an array of values and keep
								         only rows where every value is positive (missing cells default to 0) */
								      let flat = (dataset as unknown as DSet)
								        .map(({ xs, ys }) => ({ xs: Object.values(xs), ys: Object.values(ys) }))
								        .filter(({ xs, ys }) => [...xs, ...ys].every(v => v > 0));

								      /* normalize manually :( -- first pass finds the min/max of each column */
								      let minX = Infinity, maxX = -Infinity, minY = Infinity, maxY = -Infinity;
								      await flat.forEachAsync(({ xs, ys }) => {
								        minX = Math.min(minX, xs[0]); maxX = Math.max(maxX, xs[0]);
								        minY = Math.min(minY, ys[0]); maxY = Math.max(maxY, ys[0]);
								      });

								      /* an empty or constant column would make the scaling divide by zero
								         (or by -Infinity) and fill the table with NaN; fail loudly instead */
								      if (!(maxX > minX) || !(maxY > minY)) {
								        throw new Error("not enough distinct data points to normalize");
								      }
								      const spanX = maxX - minX, spanY = maxY - minY;

								      flat = flat.map(({ xs, ys }) => ({
								        xs: xs.map(v => (v - minX) / spanX),
								        ys: ys.map(v => (v - minY) / spanY)
								      }));
								      const batch = flat.batch(32);

								      /* build and train model */
								      const model = sequential();
								      model.add(layers.dense({ inputShape: [1], units: 1 }));
								      model.compile({ optimizer: train.sgd(0.000001), loss: 'meanSquaredError' });
								      await model.fitDataset(batch, { epochs: 100, callbacks: { onEpochEnd: async (epoch, logs) => {
								        setOutput(`${epoch}:${logs?.loss}`);
								      }}});

								      /* predict values */
								      const inp = linspace(0, 1, 9);
								      const pred = model.predict(inp) as Tensor<Rank>;

								      /* data() is the asynchronous download; dataSync() blocks and
								         returns a plain typed array, so awaiting it gained nothing */
								      const [xs, ys] = await Promise.all([inp.data(), pred.data()]);

								      /* tensors and model weights are not garbage collected by the backend */
								      inp.dispose(); pred.dispose(); model.dispose();

								      /* map the normalized values back to the original units */
								      setResults(Array.from(xs).map((x, i): [number, number] => [ x * spanX + minX, ys[i] * spanY + minY ]));
								      setOutput("");
								    } catch(e) {
								      setOutput(`ERROR: ${String(e)}`);
								    } finally {
								      /* the CSV dataset re-reads the URL on every pass, so it can only be
								         released here, once training has finished or failed */
								      if (url !== undefined) URL.revokeObjectURL(url);
								      setDisabled(false);
								    }
								  }, []);

								  return ( <>
								    <button onclick={doit} disabled={disabled}>Click to run</button><br/>
								    {output ? <pre>{output}</pre> : <></>}
								    {results.length > 0 ? <table><thead><tr><th>Horsepower</th><th>MPG</th></tr></thead><tbody>
								    {results.map((r,i) => <tr key={i}><td>{r[0]}</td><td>{r[1].toFixed(2)}</td></tr>)}
								    </tbody></table> : <></>}
								  </> );
								}
							 |