diff --git a/src/LLMPicker.tsx b/src/LLMPicker.tsx
index 44fbdfe..eaafeaa 100644
--- a/src/LLMPicker.tsx
+++ b/src/LLMPicker.tsx
@@ -7,17 +7,20 @@ import {
Select,
Stack,
StackProps,
+ Text,
+ Title,
Tooltip,
} from "@mantine/core";
+import { IconCpu, IconRobotFace } from "@tabler/icons-react";
-import { IconRobotFace } from "@tabler/icons-react";
import { useDisclosure } from "@mantine/hooks";
import { useMLEngine } from "./MLEngineContext";
function LLMPicker(props: StackProps) {
const [opened, { open, close }] = useDisclosure(false);
- const { loadingModel, activeModel, selectModel, modelList } = useMLEngine();
+ const { loadingModel, activeModel, selectModel, modelList, gpuVendor } =
+ useMLEngine();
return (
@@ -61,17 +64,31 @@ function LLMPicker(props: StackProps) {
radius="md"
opened={opened}
onClose={close}
- title="Select model"
+ title={
+
+
+ LLM
+
+ }
position="bottom"
size={200}
>
-
);
diff --git a/src/MLEngineContext.tsx b/src/MLEngineContext.tsx
index 2484571..cab5237 100644
--- a/src/MLEngineContext.tsx
+++ b/src/MLEngineContext.tsx
@@ -54,6 +54,7 @@ const modelList = [
type MLEngineContext = {
activeModel: string | null;
selectedModel: string | null;
+ gpuVendor: string | null;
loadingModel: {
name: string;
progress: number;
@@ -69,6 +70,7 @@ type MLEngineContext = {
const MLEngineContext = createContext({
activeModel: null,
selectedModel: null,
+ gpuVendor: null,
loadingModel: null,
engine: { current: null },
selectModel: () => {},
@@ -81,6 +83,7 @@ export function MLEngineContextProvider({ children }: { children: ReactNode }) {
const [loadingModel, setLoadingModel] = useState(null);
const [loadingProgress, setLoadingProgress] = useState(null);
const [runningModel, setRunningModel] = useState(null);
+ const [gpuVendor, setGpuVendor] = useState(null);
const [selectedModel, setSelectedModel] = useLocalStorage({
key: "modelId",
@@ -95,21 +98,20 @@ export function MLEngineContextProvider({ children }: { children: ReactNode }) {
initProgress
) => {
setLoadingProgress(initProgress.progress);
-
- if (
- initProgress.progress === 1 &&
- initProgress.text.startsWith("Finish loading")
- ) {
- setRunningModel(selectedModel);
- setLoadingModel(null);
- setLoadingProgress(null);
- }
};
- engine.current = await CreateMLCEngine(
- selectedModel,
- { initProgressCallback: initProgressCallback } // engineConfig
- );
+ engine.current = await CreateMLCEngine(selectedModel, {
+ initProgressCallback: initProgressCallback,
+ });
+
+ setRunningModel(selectedModel);
+ setLoadingModel(null);
+ setLoadingProgress(null);
+
+ const gpuVendor = await engine.current?.getGPUVendor();
+ if (gpuVendor) {
+ setGpuVendor(gpuVendor);
+ }
})();
}
}, [
@@ -119,6 +121,7 @@ export function MLEngineContextProvider({ children }: { children: ReactNode }) {
setRunningModel,
setLoadingModel,
setLoadingProgress,
+ setGpuVendor,
]);
return (
@@ -131,6 +134,7 @@ export function MLEngineContextProvider({ children }: { children: ReactNode }) {
: null,
activeModel: runningModel,
selectedModel,
+ gpuVendor,
selectModel: setSelectedModel,
modelList,
}}