Skip to content

Commit

Permalink
Using keys to preserve values between reloads (#8056)
Browse files Browse the repository at this point in the history
* changes

* add changeset

* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* rev pn

* changes

* changes

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
4 people committed Apr 25, 2024
1 parent 0a42e96 commit 2e469a5
Show file tree
Hide file tree
Showing 10 changed files with 231 additions and 150 deletions.
7 changes: 7 additions & 0 deletions .changeset/heavy-crabs-ring.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/app": minor
"@gradio/client": minor
"gradio": minor
---

feat:Using keys to preserve values between reloads
7 changes: 6 additions & 1 deletion client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export class Client {
pending_diff_streams: Record<string, any[][]> = {};
event_callbacks: Record<string, () => Promise<void>> = {};
unclosed_events: Set<string> = new Set();
heartbeat_event: EventSource | null = null;

fetch_implementation(
input: RequestInfo | URL,
Expand Down Expand Up @@ -129,7 +130,7 @@ export class Client {
const heartbeat_url = new URL(
`${this.config.root}/heartbeat/${this.session_hash}`
);
this.eventSource_factory(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
this.heartbeat_event = this.eventSource_factory(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540

if (this.config.space_id && this.options.hf_token) {
this.jwt = await get_jwt(
Expand Down Expand Up @@ -157,6 +158,10 @@ export class Client {
return client;
}

close(): void {
this.heartbeat_event?.close();
}

static async duplicate(
app_reference: string,
options: DuplicateOptions = {}
Expand Down
7 changes: 7 additions & 0 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
self.state_session_capacity = 10000
self.temp_files: set[str] = set()
self.GRADIO_CACHE = get_upload_folder()
self.key: int | str | None = None
# Keep tracks of files that should not be deleted when the delete_cache parmaeter is set
# These files are the default value of the component and files that are used in examples
self.keep_in_cache = set()
Expand Down Expand Up @@ -640,6 +641,7 @@ def get_layout(block):
"props": utils.delete_none(props),
"skip_api": block.skip_api,
"component_class_id": getattr(block, "component_class_id", None),
"key": block.key,
}
if not block.skip_api:
block_config["api_info"] = block.api_info() # type: ignore
Expand Down Expand Up @@ -2162,6 +2164,11 @@ def reverse(text):
self.validate_queue_settings()
self.max_file_size = utils._parse_file_size(max_file_size)

if self.dev_mode:
for block in self.blocks.values():
if block.key is None:
block.key = f"__{block._id}__"

self.config = self.get_config_file()
self.max_threads = max_threads
self._queue.max_thread_count = max_threads
Expand Down
4 changes: 2 additions & 2 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,11 @@ async def reload_checker(request: fastapi.Request):

if app.change_event and app.change_event.is_set():
app.change_event.clear()
yield """data: CHANGE\n\n"""
yield """event: reload\ndata: {}\n\n"""

await asyncio.sleep(check_rate)
if time.perf_counter() - last_heartbeat > heartbeat_rate:
yield """data: HEARTBEAT\n\n"""
yield """event: heartbeat\ndata: {}\n\n"""
last_heartbeat = time.time()

return StreamingResponse(
Expand Down
57 changes: 43 additions & 14 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,13 @@ class BaseReloader(ABC):
def running_app(self) -> App:
pass

def queue_changed(self, demo: Blocks):
return (
hasattr(self.running_app.blocks, "_queue") and not hasattr(demo, "_queue")
) or (
not hasattr(self.running_app.blocks, "_queue") and hasattr(demo, "_queue")
)

def swap_blocks(self, demo: Blocks):
assert self.running_app.blocks # noqa: S101
# Copy over the blocks to get new components and events but
# not a new queue
self.running_app.blocks._queue.block_fns = demo.fns
demo._queue = self.running_app.blocks._queue
demo.max_file_size = self.running_app.blocks.max_file_size
self.running_app.state_holder.reset(demo)
self.running_app.blocks = demo
demo._queue.reload()
Expand Down Expand Up @@ -155,7 +149,11 @@ def alert_change(self):
self.change_event.set()

def swap_blocks(self, demo: Blocks):
old_blocks = self.running_app.blocks
super().swap_blocks(demo)
if old_blocks:
reassign_keys(old_blocks, demo)
demo.config = demo.get_config_file()
self.alert_change()


Expand Down Expand Up @@ -285,17 +283,48 @@ def iter_py_files() -> Iterator[Path]:
mtimes = {}
continue
demo = getattr(module, reloader.demo_name)
if reloader.queue_changed(demo):
print(
"Reloading failed. The new demo has a queue and the old one doesn't (or vice versa). "
"Please launch your demo again"
)
else:
reloader.swap_blocks(demo)
reloader.swap_blocks(demo)
mtimes = {}
time.sleep(0.05)


def reassign_keys(old_blocks: Blocks, new_blocks: Blocks):
from gradio.blocks import BlockContext

assigned_keys = [
block.key for block in new_blocks.children if block.key is not None
]

def reassign_context_keys(
old_context: BlockContext | None, new_context: BlockContext
):
for i, new_block in enumerate(new_context.children):
if old_context and i < len(old_context.children):
old_block = old_context.children[i]
else:
old_block = None
if new_block.key is None:
if (
old_block.__class__ == new_block.__class__
and old_block is not None
and old_block.key not in assigned_keys
):
new_block.key = old_block.key
else:
new_block.key = f"__{new_block._id}__"

if isinstance(new_block, BlockContext):
if (
isinstance(old_block, BlockContext)
and old_block.__class__ == new_block.__class__
):
reassign_context_keys(old_block, new_block)
else:
reassign_context_keys(None, new_block)

reassign_context_keys(old_blocks, new_blocks)


def colab_check() -> bool:
"""
Check if interface is launching from Google Colab
Expand Down
10 changes: 8 additions & 2 deletions js/app/src/Blocks.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@
import logo from "./images/logo.svg";
import api_logo from "./api_docs/img/api-logo.svg";
import { create_components, AsyncFunction } from "./init";
import {
create_components,
AsyncFunction,
restore_keyed_values
} from "./init";
setupi18n();
export let root: string;
export let components: ComponentMeta[];
let old_components: ComponentMeta[] = components;
export let layout: LayoutNode;
export let dependencies: Dependency[];
export let title = "Gradio";
Expand Down Expand Up @@ -60,7 +65,8 @@
app,
options: {
fill_height
}
},
callback: () => restore_keyed_values(old_components, components)
});
$: {
Expand Down
25 changes: 12 additions & 13 deletions js/app/src/Index.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -307,21 +307,20 @@
const { host } = new URL(api_url);
let url = new URL(`http://${host}/dev/reload`);
eventSource = new EventSource(url);
eventSource.onmessage = async function (event) {
if (event.data === "CHANGE") {
app = await Client.connect(api_url, {
status_callback: handle_status
});
if (!app.config) {
throw new Error("Could not resolve app config");
}
eventSource.addEventListener("reload", async (event) => {
app.close();
app = await Client.connect(api_url, {
status_callback: handle_status
});
config = app.config;
window.__gradio_space__ = config.space_id;
await mount_custom_css(config.css);
if (!app.config) {
throw new Error("Could not resolve app config");
}
};
config = app.config;
window.__gradio_space__ = config.space_id;
await mount_custom_css(config.css);
});
}, 200);
}
});
Expand Down
31 changes: 29 additions & 2 deletions js/app/src/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export function create_components(): {
options: {
fill_height: boolean;
};
callback?: () => void;
}) => void;
} {
let _component_map: Map<number, ComponentMeta>;
Expand All @@ -61,7 +62,8 @@ export function create_components(): {
layout,
dependencies,
root,
options
options,
callback
}: {
app: client_return;
components: ComponentMeta[];
Expand All @@ -71,6 +73,7 @@ export function create_components(): {
options: {
fill_height: boolean;
};
callback?: () => void;
}): void {
app = _app;
_components = components;
Expand All @@ -89,7 +92,8 @@ export function create_components(): {
has_modes: false,
instance: null as unknown as ComponentMeta["instance"],
component: null as unknown as ComponentMeta["component"],
component_class_id: ""
component_class_id: "",
key: null
};

components.push(_rootNode);
Expand Down Expand Up @@ -120,6 +124,9 @@ export function create_components(): {

walk_layout(layout, root).then(() => {
layout_store.set(_rootNode);
if (callback) {
callback();
}
});
}

Expand Down Expand Up @@ -481,3 +488,23 @@ export function preload_all_components(

return constructor_map;
}

export const restore_keyed_values = (
old_components: ComponentMeta[],
new_components: ComponentMeta[]
): void => {
let component_values_by_key: Record<string | number, ComponentMeta> = {};
old_components.forEach((component) => {
if (component.key) {
component_values_by_key[component.key] = component;
}
});
new_components.forEach((component) => {
if (component.key) {
const old_component = component_values_by_key[component.key];
if (old_component) {
component.props.value = old_component.props.value;
}
}
});
};
1 change: 1 addition & 0 deletions js/app/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export interface ComponentMeta {
children?: ComponentMeta[];
value?: any;
component_class_id: string;
key: string | number | null;
}

/** Dictates whether a dependency is continous and/or a generator */
Expand Down
Loading

0 comments on commit 2e469a5

Please sign in to comment.