Skip to content

Commit 79b24ba

Browse files
Prototype of a stepper composable helper for making backstacks.
1 parent 9056cc7 commit 79b24ba

File tree

2 files changed

+316
-0
lines changed

2 files changed

+316
-0
lines changed
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
package com.squareup.workflow1.compose
2+
3+
import androidx.collection.ScatterMap
4+
import androidx.collection.mutableScatterMapOf
5+
import androidx.compose.runtime.Composable
6+
import androidx.compose.runtime.getValue
7+
import androidx.compose.runtime.mutableStateListOf
8+
import androidx.compose.runtime.mutableStateOf
9+
import androidx.compose.runtime.remember
10+
import androidx.compose.runtime.setValue
11+
import androidx.compose.runtime.snapshots.Snapshot
12+
import androidx.compose.runtime.snapshots.StateObject
13+
import androidx.compose.runtime.snapshots.StateRecord
14+
import androidx.compose.runtime.snapshots.readable
15+
import androidx.compose.runtime.snapshots.writable
16+
17+
/**
18+
* Composes [content] and returns its return value in a list.
19+
*
20+
* Every time [content] calls [Stepper.advance] its argument is passed to [advance] and [advance] is
21+
* expected to update some states that are read inside [content]. The current values of all states
22+
* changed by [advance] are saved into a frame and pushed onto the backstack along with the last
23+
* value returned by [content].
24+
*
25+
* When [Stepper.goBack] is called the last frame is popped and all the states that were written by
26+
* [advance] are restored before recomposing [content].
27+
*
28+
* @sample com.squareup.workflow1.compose.StepperDemo
29+
*/
30+
@Composable
31+
public fun <T, R> stepper(
32+
advance: (T) -> Unit,
33+
content: @Composable Stepper<T, R>.() -> R
34+
): List<R> {
35+
// TODO figure out how to support rememberSaveable
36+
val stepperImpl = remember { StepperImpl<T, R>(advance = advance) }
37+
stepperImpl.advance = advance
38+
stepperImpl.lastRendering = content(stepperImpl)
39+
return stepperImpl.renderings
40+
}
41+
42+
/**
43+
* Composes [content] and returns its return value in a list. Every time [content] calls
44+
* [Stepper.advance] the current values of all states changed by the `toState` block are
45+
* saved into a frame and pushed onto the backstack along with the last value returned by [content].
46+
* When [Stepper.goBack] is called the last frame is popped and all the states that were
47+
* written by the `toState` block are restored before recomposing [content].
48+
*
49+
* This is an overload of [stepper] that makes it easier to specify the state update function when
50+
* calling [Stepper.advance] instead of defining it ahead of time.
51+
*
52+
* @sample com.squareup.workflow1.compose.StepperInlineDemo
53+
*/
54+
// Impl note: Inline since this is just syntactic sugar, no reason to generate bytecode/API for it.
55+
@Suppress("NOTHING_TO_INLINE")
56+
@Composable
57+
public inline fun <R> stepper(
58+
noinline content: @Composable Stepper<() -> Unit, R>.() -> R
59+
): List<R> = stepper(advance = { it() }, content = content)
60+
61+
public interface Stepper<T, R> {
62+
63+
/** The (possibly empty) stack of steps that came before the current one. */
64+
val previousSteps: List<Step<R>>
65+
66+
/**
67+
* Pushes a new frame onto the backstack with the current state and then runs [toState].
68+
*/
69+
fun advance(toState: T)
70+
71+
/**
72+
* Pops the last frame off the backstack and restores its state.
73+
*
74+
* @return False if the stack was empty (i.e. this is a noop).
75+
*/
76+
fun goBack(): Boolean
77+
}
78+
79+
public interface Step<T> {
80+
/** The last rendering produced by this step. */
81+
val rendering: T
82+
83+
/**
84+
* Runs [block] inside a snapshot such that the step state is set to its saved values from this
85+
* step. The snapshot is read-only, so writing to any snapshot state objects will throw.
86+
*/
87+
fun <R> peekStateFromStep(block: () -> R): R
88+
}
89+
90+
private class StepperImpl<T, R>(
91+
advance: (T) -> Unit
92+
) : Stepper<T, R> {
93+
var advance: (T) -> Unit by mutableStateOf(advance)
94+
private val savePoints = mutableStateListOf<SavePoint>()
95+
var lastRendering by mutableStateOf<Any?>(NO_RENDERING)
96+
97+
val renderings: List<R>
98+
get() = buildList(capacity = savePoints.size + 1) {
99+
savePoints.mapTo(this) { it.rendering }
100+
@Suppress("UNCHECKED_CAST")
101+
add(lastRendering as R)
102+
}
103+
104+
override val previousSteps: List<Step<R>>
105+
get() = savePoints
106+
107+
override fun advance(toState: T) {
108+
check(lastRendering !== NO_RENDERING) { "advance called before first composition" }
109+
110+
// Take an outer snapshot so all the state mutations in withState get applied atomically with
111+
// our internal state update (to savePoints).
112+
Snapshot.withMutableSnapshot {
113+
val savedRecords = mutableScatterMapOf<StateObject, StateRecord?>()
114+
val snapshot = Snapshot.takeMutableSnapshot(
115+
writeObserver = {
116+
// Don't save the value of the object yet, we want the value _before_ the write, so we
117+
// need to read it outside this inner snapshot.
118+
savedRecords[it as StateObject] = null
119+
}
120+
)
121+
try {
122+
// Record what state objects are written by the block.
123+
snapshot.enter { this.advance.invoke(toState) }
124+
125+
// Save the _current_ values of those state objects so we can restore them later.
126+
// TODO Need to think more about which state objects need to be saved and restored for a
127+
// particular frame. E.g. probably we should track all objects that were written for the
128+
// current frame, and save those as well, even if they're not written by the _next_ frame.
129+
savedRecords.forEachKey { stateObject ->
130+
savedRecords[stateObject] = stateObject.copyCurrentRecord()
131+
}
132+
133+
// This should never fail since we're already in a snapshot and no other state has been
134+
// written by this point, but check just in case.
135+
val advanceApplyResult = snapshot.apply()
136+
if (advanceApplyResult.succeeded) {
137+
// This cast is fine, we know we've assigned a non-null value to all entries.
138+
@Suppress("UNCHECKED_CAST")
139+
savePoints += SavePoint(
140+
savedRecords = savedRecords as ScatterMap<StateObject, StateRecord>,
141+
rendering = lastRendering as R,
142+
)
143+
}
144+
// If !succeeded, throw the standard error.
145+
advanceApplyResult.check()
146+
} finally {
147+
snapshot.dispose()
148+
}
149+
}
150+
}
151+
152+
override fun goBack(): Boolean {
153+
Snapshot.withMutableSnapshot {
154+
if (savePoints.isEmpty()) return false
155+
val toRestore = savePoints.removeAt(savePoints.lastIndex)
156+
157+
// Restore all state objects' saved values.
158+
toRestore.restoreState()
159+
160+
// Don't need to restore the last rendering, it will be computed fresh by the imminent
161+
// recomposition.
162+
}
163+
return true
164+
}
165+
166+
/**
167+
* Returns a copy of the current readable record of this state object. A copy is needed since
168+
* active records can be mutated by other snapshots.
169+
*/
170+
private fun StateObject.copyCurrentRecord(): StateRecord {
171+
val record = firstStateRecord.readable(this)
172+
// Records can be mutated in other snapshots, so create a copy.
173+
return record.create().apply { assign(record) }
174+
}
175+
176+
/**
177+
* Sets the value of this state object to a [record] that was previously copied via
178+
* [copyCurrentRecord].
179+
*/
180+
private fun StateObject.restoreRecord(record: StateRecord) {
181+
firstStateRecord.writable(this) { assign(record) }
182+
}
183+
184+
private inner class SavePoint(
185+
val savedRecords: ScatterMap<StateObject, StateRecord>,
186+
override val rendering: R,
187+
) : Step<R> {
188+
override fun <R> peekStateFromStep(block: () -> R): R {
189+
// Need a mutable snapshot to restore state.
190+
val restoreSnapshot = Snapshot.takeMutableSnapshot()
191+
try {
192+
restoreSnapshot.enter {
193+
restoreState()
194+
195+
// Now take a read-only snapshot to enforce contract.
196+
val readOnlySnapshot = Snapshot.takeSnapshot()
197+
try {
198+
return readOnlySnapshot.enter(block)
199+
} finally {
200+
readOnlySnapshot.dispose()
201+
}
202+
}
203+
} finally {
204+
restoreSnapshot.dispose()
205+
}
206+
}
207+
208+
fun restoreState() {
209+
savedRecords.forEach { stateObject, record ->
210+
stateObject.restoreRecord(record)
211+
}
212+
}
213+
}
214+
215+
companion object {
216+
val NO_RENDERING = Any()
217+
}
218+
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package com.squareup.workflow1.compose
2+
3+
import androidx.compose.runtime.Composable
4+
import androidx.compose.runtime.getValue
5+
import androidx.compose.runtime.mutableStateOf
6+
import androidx.compose.runtime.saveable.rememberSaveable
7+
import androidx.compose.runtime.setValue
8+
import androidx.compose.ui.util.fastJoinToString
9+
import com.squareup.workflow1.compose.DemoStep.ONE
10+
import com.squareup.workflow1.compose.DemoStep.THREE
11+
import com.squareup.workflow1.compose.DemoStep.TWO
12+
import com.squareup.workflow1.compose.Screen.ScreenOne
13+
import com.squareup.workflow1.compose.Screen.ScreenThree
14+
import com.squareup.workflow1.compose.Screen.ScreenTwo
15+
16+
internal enum class DemoStep {
17+
ONE,
18+
TWO,
19+
THREE,
20+
}
21+
22+
internal sealed interface Screen {
23+
val message: String
24+
25+
data class ScreenOne(
26+
override val message: String,
27+
val onNextClicked: () -> Unit,
28+
) : Screen
29+
30+
data class ScreenTwo(
31+
override val message: String,
32+
val onNextClicked: () -> Unit,
33+
val onBack: () -> Unit,
34+
) : Screen
35+
36+
data class ScreenThree(
37+
override val message: String,
38+
val onBack: () -> Unit,
39+
) : Screen
40+
}
41+
42+
@Composable
43+
internal fun StepperDemo() {
44+
var step by rememberSaveable { mutableStateOf(ONE) }
45+
println("step=$step")
46+
47+
val stack: List<Screen> = stepper(advance = { step = it }) {
48+
val breadcrumbs = previousSteps.fastJoinToString(separator = " > ") { it.rendering.message }
49+
when (step) {
50+
ONE -> ScreenOne(
51+
message = "Step one",
52+
onNextClicked = { advance(TWO) },
53+
)
54+
55+
TWO -> ScreenTwo(
56+
message = "Step two",
57+
onNextClicked = { advance(THREE) },
58+
onBack = { goBack() },
59+
)
60+
61+
THREE -> ScreenThree(
62+
message = "Step three",
63+
onBack = { goBack() },
64+
)
65+
}
66+
}
67+
68+
println("stack = ${stack.fastJoinToString()}")
69+
}
70+
71+
@Composable
72+
internal fun StepperInlineDemo() {
73+
var step by rememberSaveable { mutableStateOf(ONE) }
74+
println("step=$step")
75+
76+
val stack: List<Screen> = stepper {
77+
val breadcrumbs = previousSteps.fastJoinToString(separator = " > ") { it.rendering.message }
78+
when (step) {
79+
ONE -> ScreenOne(
80+
message = "Step one",
81+
onNextClicked = { advance { step = TWO } },
82+
)
83+
84+
TWO -> ScreenTwo(
85+
message = "Step two",
86+
onNextClicked = { advance { step = THREE } },
87+
onBack = { goBack() },
88+
)
89+
90+
THREE -> ScreenThree(
91+
message = "Step three",
92+
onBack = { goBack() },
93+
)
94+
}
95+
}
96+
97+
println("stack = ${stack.fastJoinToString()}")
98+
}

0 commit comments

Comments
 (0)