Skip to content

Commit c645e9c

Browse files
committed
fix bad merge
1 parent d0e5cd9 commit c645e9c

3 files changed

Lines changed: 31 additions & 3 deletions

File tree

marimo/_session/extensions/extensions.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from __future__ import annotations
99

1010
import asyncio
11+
from enum import Enum
1112
from typing import TYPE_CHECKING, Optional
1213

1314
from marimo import _loggers
@@ -97,6 +98,11 @@ def _stop(self) -> None:
9798
self.heartbeat_task.cancel()
9899

99100

101+
class CacheMode(Enum):
102+
READ = "read"
103+
READ_WRITE = "write"
104+
105+
100106
class CachingExtension(SessionExtension, SessionEventListener):
101107
"""Extension for caching session state to disk.
102108
@@ -111,15 +117,18 @@ def __init__(
111117
*,
112118
enabled: bool,
113119
interval: int = SESSION_CACHE_INTERVAL_SECONDS,
120+
mode: CacheMode = CacheMode.READ_WRITE,
114121
) -> None:
115122
"""Initialize the caching extension.
116123
117124
Args:
118125
enabled: Whether to enable caching
119126
interval: How often to write cache (in seconds)
127+
mode: Whether to read cache only or read/write.
120128
"""
121129
self.interval = interval
122130
self.enabled = enabled
131+
self.mode = mode
123132
self.session_cache_manager: Optional[SessionCacheManager] = None
124133
self.event_bus: Optional[SessionEventBus] = None
125134

@@ -165,7 +174,8 @@ def on_attach(self, session: Session, event_bus: SessionEventBus) -> None:
165174
)
166175

167176
# Start the background task to write the session view to disk
168-
self.session_cache_manager.start()
177+
if self.mode is CacheMode.READ_WRITE:
178+
self.session_cache_manager.start()
169179

170180
def on_detach(self) -> None:
171181
"""Stop cache manager when detached."""
@@ -179,6 +189,8 @@ async def on_session_notebook_renamed(
179189
) -> None:
180190
"""Rename the path for the cache manager."""
181191
del old_path
192+
if self.mode is not CacheMode.READ_WRITE:
193+
return None
182194
path = session.app_file_manager.path
183195
if self.session_cache_manager and path:
184196
self.session_cache_manager.rename_path(path)

marimo/_session/session.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from marimo._session.consumer import SessionConsumer
3030
from marimo._session.events import SessionEventBus
3131
from marimo._session.extensions.extensions import (
32+
CacheMode,
3233
CachingExtension,
3334
HeartbeatExtension,
3435
LoggingExtension,
@@ -171,12 +172,22 @@ def create(
171172
redirect_console_to_browser=redirect_console_to_browser,
172173
)
173174

175+
if mode == SessionMode.EDIT:
176+
cache_enabled = not auto_instantiate
177+
cache_mode = CacheMode.READ_WRITE
178+
else:
179+
cache_enabled = config_manager.get_config()["runtime"].get(
180+
"serve_cached_sessions_in_apps", False
181+
)
182+
cache_mode = CacheMode.READ
183+
174184
extensions = [
175185
*(extensions or []),
176186
LoggingExtension(),
177187
HeartbeatExtension(),
178188
CachingExtension(
179-
enabled=not auto_instantiate and mode == SessionMode.EDIT
189+
enabled=cache_enabled,
190+
mode=cache_mode,
180191
),
181192
NotificationListenerExtension(
182193
kernel_manager=kernel_manager, queue_manager=queue_manager

tests/_session/app_host/test_app_host.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ def test_create_session_uses_per_client_session_id(self) -> None:
420420
"""
421421
from unittest.mock import Mock, patch
422422

423+
from marimo._config.config import DEFAULT_CONFIG
423424
from marimo._session.app_host import AppHostContext
424425
from marimo._session.app_host.host import AppHost
425426
from marimo._session.model import SessionMode
@@ -444,7 +445,11 @@ def test_create_session_uses_per_client_session_id(self) -> None:
444445
app_file_manager=Mock(path="/tmp/test_app.py"),
445446
config_manager=Mock(
446447
with_overrides=Mock(
447-
return_value=Mock(get_config=Mock(return_value={}))
448+
return_value=Mock(
449+
get_config=Mock(
450+
return_value=DEFAULT_CONFIG
451+
)
452+
)
448453
)
449454
),
450455
virtual_files_supported=True,

0 commit comments

Comments
 (0)