diff --git a/src/LLMPicker.tsx b/src/LLMPicker.tsx index 44fbdfe..eaafeaa 100644 --- a/src/LLMPicker.tsx +++ b/src/LLMPicker.tsx @@ -7,17 +7,20 @@ import { Select, Stack, StackProps, + Text, + Title, Tooltip, } from "@mantine/core"; +import { IconCpu, IconRobotFace } from "@tabler/icons-react"; -import { IconRobotFace } from "@tabler/icons-react"; import { useDisclosure } from "@mantine/hooks"; import { useMLEngine } from "./MLEngineContext"; function LLMPicker(props: StackProps) { const [opened, { open, close }] = useDisclosure(false); - const { loadingModel, activeModel, selectModel, modelList } = useMLEngine(); + const { loadingModel, activeModel, selectModel, modelList, gpuVendor } = + useMLEngine(); return ( @@ -61,17 +64,31 @@ function LLMPicker(props: StackProps) { radius="md" opened={opened} onClose={close} - title="Select model" + title={ + + + LLM + + } position="bottom" size={200} > - val && selectModel(val)} + searchable + clearable + hiddenFrom="sm" + /> + {gpuVendor && ( + + + {gpuVendor} + + )} + ); diff --git a/src/MLEngineContext.tsx b/src/MLEngineContext.tsx index 2484571..cab5237 100644 --- a/src/MLEngineContext.tsx +++ b/src/MLEngineContext.tsx @@ -54,6 +54,7 @@ const modelList = [ type MLEngineContext = { activeModel: string | null; selectedModel: string | null; + gpuVendor: string | null; loadingModel: { name: string; progress: number; @@ -69,6 +70,7 @@ type MLEngineContext = { const MLEngineContext = createContext({ activeModel: null, selectedModel: null, + gpuVendor: null, loadingModel: null, engine: { current: null }, selectModel: () => {}, @@ -81,6 +83,7 @@ export function MLEngineContextProvider({ children }: { children: ReactNode }) { const [loadingModel, setLoadingModel] = useState(null); const [loadingProgress, setLoadingProgress] = useState(null); const [runningModel, setRunningModel] = useState(null); + const [gpuVendor, setGpuVendor] = useState(null); const [selectedModel, setSelectedModel] = useLocalStorage({ key: "modelId", @@ -95,21 +98,20 @@ export function MLEngineContextProvider({ children }: { children: ReactNode }) { initProgress ) => { setLoadingProgress(initProgress.progress); - - if ( - initProgress.progress === 1 && - initProgress.text.startsWith("Finish loading") - ) { - setRunningModel(selectedModel); - setLoadingModel(null); - setLoadingProgress(null); - } }; - engine.current = await CreateMLCEngine( - selectedModel, - { initProgressCallback: initProgressCallback } // engineConfig - ); + engine.current = await CreateMLCEngine(selectedModel, { + initProgressCallback: initProgressCallback, + }); + + setRunningModel(selectedModel); + setLoadingModel(null); + setLoadingProgress(null); + + const gpuVendor = await engine.current?.getGPUVendor(); + if (gpuVendor) { + setGpuVendor(gpuVendor); + } })(); } }, [ @@ -119,6 +121,7 @@ export function MLEngineContextProvider({ children }: { children: ReactNode }) { setRunningModel, setLoadingModel, setLoadingProgress, + setGpuVendor, ]); return ( @@ -131,6 +134,7 @@ export function MLEngineContextProvider({ children }: { children: ReactNode }) { : null, activeModel: runningModel, selectedModel, + gpuVendor, selectModel: setSelectedModel, modelList, }}