diff --git a/src/MLEngineContext.tsx b/src/MLEngineContext.tsx index cab5237..c1e1891 100644 --- a/src/MLEngineContext.tsx +++ b/src/MLEngineContext.tsx @@ -94,15 +94,20 @@ export function MLEngineContextProvider({ children }: { children: ReactNode }) { if (selectedModel && runningModel !== selectedModel) { (async () => { setLoadingModel(selectedModel); + const initProgressCallback: InitProgressCallback = async ( initProgress ) => { setLoadingProgress(initProgress.progress); }; - engine.current = await CreateMLCEngine(selectedModel, { - initProgressCallback: initProgressCallback, - }); + if (!engine.current) { + engine.current = await CreateMLCEngine(selectedModel, { + initProgressCallback: initProgressCallback, + }); + } else { + await engine.current.reload(selectedModel); + } setRunningModel(selectedModel); setLoadingModel(null); @@ -113,6 +118,8 @@ export function MLEngineContextProvider({ children }: { children: ReactNode }) { setGpuVendor(gpuVendor); } })(); + } else if (!selectedModel && engine.current) { + engine.current.unload(); } }, [ engine,