]> git.ipfire.org Git - thirdparty/vuejs/core.git/commitdiff
fix(sfc): fix `<script setup>` async context preservation logic
authorEvan You <yyx990803@gmail.com>
Tue, 6 Jul 2021 18:31:53 +0000 (14:31 -0400)
committerEvan You <yyx990803@gmail.com>
Tue, 6 Jul 2021 18:31:53 +0000 (14:31 -0400)
fix #4050

packages/compiler-sfc/__tests__/compileScript.spec.ts
packages/compiler-sfc/src/compileScript.ts
packages/runtime-core/__tests__/apiSetupHelpers.spec.ts
packages/runtime-core/src/apiSetupHelpers.ts

index 1da416b36fb6d6a37df897f2e403d594b6944012..d75927d33bb89dad2c51001f69cbd7fce955ef95 100644 (file)
@@ -847,6 +847,9 @@ const emit = defineEmits(['a', 'b'])
       const { content } = compile(`<script setup>${code}</script>`, {
         refSugar: true
       })
+      if (shouldAsync) {
+        expect(content).toMatch(`let __temp, __restore`)
+      }
       expect(content).toMatch(`${shouldAsync ? `async ` : ``}setup(`)
       if (typeof expected === 'string') {
         expect(content).toMatch(expected)
@@ -856,28 +859,35 @@ const emit = defineEmits(['a', 'b'])
     }
 
     test('expression statement', () => {
-      assertAwaitDetection(`await foo`, `await _withAsyncContext(foo)`)
+      assertAwaitDetection(
+        `await foo`,
+        `;(([__temp,__restore]=_withAsyncContext(()=>(foo))),__temp=await __temp,__restore())`
+      )
     })
 
     test('variable', () => {
       assertAwaitDetection(
         `const a = 1 + (await foo)`,
-        `1 + (await _withAsyncContext(foo))`
+        `1 + ((([__temp,__restore]=_withAsyncContext(()=>(foo))),__temp=await __temp,__restore(),__temp))`
       )
     })
 
     test('ref', () => {
       assertAwaitDetection(
         `ref: a = 1 + (await foo)`,
-        `1 + (await _withAsyncContext(foo))`
+        `1 + ((([__temp,__restore]=_withAsyncContext(()=>(foo))),__temp=await __temp,__restore(),__temp))`
       )
     })
 
     test('nested statements', () => {
       assertAwaitDetection(`if (ok) { await foo } else { await bar }`, code => {
         return (
-          code.includes(`await _withAsyncContext(foo)`) &&
-          code.includes(`await _withAsyncContext(bar)`)
+          code.includes(
+            `;(([__temp,__restore]=_withAsyncContext(()=>(foo))),__temp=await __temp,__restore())`
+          ) &&
+          code.includes(
+            `;(([__temp,__restore]=_withAsyncContext(()=>(bar))),__temp=await __temp,__restore())`
+          )
         )
       })
     })
index 94d351ab03cd2c3ff4c0f1e3f93589b3b60f9bab..e5a792d77c33ce499435382e98de920f891b30af 100644 (file)
@@ -32,7 +32,8 @@ import {
   LabeledStatement,
   CallExpression,
   RestElement,
-  TSInterfaceBody
+  TSInterfaceBody,
+  AwaitExpression
 } from '@babel/types'
 import { walk } from 'estree-walker'
 import { RawSourceMap } from 'source-map'
@@ -487,6 +488,25 @@ export function compileScript(
     })
   }
 
+  /**
+   * await foo()
+   * -->
+   * (([__temp, __restore] = withAsyncContext(() => foo())),__temp=await __temp,__restore(),__temp)
+   */
+  function processAwait(node: AwaitExpression, isStatement: boolean) {
+    s.overwrite(
+      node.start! + startOffset,
+      node.argument.start! + startOffset,
+      `${isStatement ? `;` : ``}(([__temp,__restore]=${helper(
+        `withAsyncContext`
+      )}(()=>(`
+    )
+    s.appendLeft(
+      node.end! + startOffset,
+      `))),__temp=await __temp,__restore()${isStatement ? `` : `,__temp`})`
+    )
+  }
+
   function processRefExpression(exp: Expression, statement: LabeledStatement) {
     if (exp.type === 'AssignmentExpression') {
       const { left, right } = exp
@@ -949,17 +969,13 @@ export function compileScript(
       node.type.endsWith('Statement')
     ) {
       ;(walk as any)(node, {
-        enter(node: Node) {
-          if (isFunction(node)) {
+        enter(child: Node, parent: Node) {
+          if (isFunction(child)) {
             this.skip()
           }
-          if (node.type === 'AwaitExpression') {
+          if (child.type === 'AwaitExpression') {
             hasAwait = true
-            s.prependRight(
-              node.argument.start! + startOffset,
-              helper(`withAsyncContext`) + `(`
-            )
-            s.appendLeft(node.argument.end! + startOffset, `)`)
+            processAwait(child, parent.type === 'ExpressionStatement')
           }
         }
       })
@@ -1151,6 +1167,11 @@ export function compileScript(
   if (propsIdentifier) {
     s.prependRight(startOffset, `\nconst ${propsIdentifier} = __props`)
   }
+  // inject temp variables for async context preservation
+  if (hasAwait) {
+    const any = isTS ? `:any` : ``
+    s.prependRight(startOffset, `\nlet __temp${any}, __restore${any}\n`)
+  }
 
   const destructureElements =
     hasDefineExposeCall || !options.inlineTemplate ? [`expose`] : []
index 5244a9a87b8d523d6e63a083f1a02778a1506941..3dc22e67ce5b459b5694cdc5687361ffe59569de 100644 (file)
@@ -119,12 +119,20 @@ describe('SFC <script setup> helpers', () => {
 
       const Comp = defineComponent({
         async setup() {
+          let __temp: any, __restore: any
+
           beforeInstance = getCurrentInstance()
-          const msg = await withAsyncContext(
-            new Promise(r => {
-              resolve = r
-            })
-          )
+
+          const msg = (([__temp, __restore] = withAsyncContext(
+            () =>
+              new Promise(r => {
+                resolve = r
+              })
+          )),
+          (__temp = await __temp),
+          __restore(),
+          __temp)
+
           // register the lifecycle after an await statement
           onMounted(spy)
           afterInstance = getCurrentInstance()
@@ -155,13 +163,18 @@ describe('SFC <script setup> helpers', () => {
 
       const Comp = defineComponent({
         async setup() {
+          let __temp: any, __restore: any
+
           beforeInstance = getCurrentInstance()
           try {
-            await withAsyncContext(
-              new Promise((r, rj) => {
-                reject = rj
-              })
+            ;[__temp, __restore] = withAsyncContext(
+              () =>
+                new Promise((_, rj) => {
+                  reject = rj
+                })
             )
+            __temp = await __temp
+            __restore()
           } catch (e) {
             // ignore
           }
@@ -206,11 +219,20 @@ describe('SFC <script setup> helpers', () => {
 
       const Comp = defineComponent({
         async setup() {
+          let __temp: any, __restore: any
+
           beforeInstance = getCurrentInstance()
+
           // first await
-          await withAsyncContext(Promise.resolve())
+          ;[__temp, __restore] = withAsyncContext(() => Promise.resolve())
+          __temp = await __temp
+          __restore()
+
           // setup exit, instance set to null, then resumed
-          await withAsyncContext(doAsyncWork())
+          ;[__temp, __restore] = withAsyncContext(() => doAsyncWork())
+          __temp = await __temp
+          __restore()
+
           afterInstance = getCurrentInstance()
           return () => {
             resolve()
@@ -237,8 +259,13 @@ describe('SFC <script setup> helpers', () => {
 
       const Comp = defineComponent({
         async setup() {
-          await withAsyncContext(Promise.resolve())
-          await withAsyncContext(Promise.reject())
+          let __temp: any, __restore: any
+          ;[__temp, __restore] = withAsyncContext(() => Promise.resolve())
+          __temp = await __temp
+          __restore()
+          ;[__temp, __restore] = withAsyncContext(() => Promise.reject())
+          __temp = await __temp
+          __restore()
         },
         render() {}
       })
@@ -256,6 +283,42 @@ describe('SFC <script setup> helpers', () => {
       expect(getCurrentInstance()).toBeNull()
     })
 
+    // #4050
+    test('race conditions', async () => {
+      const uids = {
+        one: { before: NaN, after: NaN },
+        two: { before: NaN, after: NaN }
+      }
+
+      const Comp = defineComponent({
+        props: ['name'],
+        async setup(props: { name: 'one' | 'two' }) {
+          let __temp: any, __restore: any
+
+          uids[props.name].before = getCurrentInstance()!.uid
+          ;[__temp, __restore] = withAsyncContext(() => Promise.resolve())
+          __temp = await __temp
+          __restore()
+
+          uids[props.name].after = getCurrentInstance()!.uid
+          return () => ''
+        }
+      })
+
+      const app = createApp(() =>
+        h(Suspense, () =>
+          h('div', [h(Comp, { name: 'one' }), h(Comp, { name: 'two' })])
+        )
+      )
+      const root = nodeOps.createElement('div')
+      app.mount(root)
+
+      await new Promise(r => setTimeout(r))
+      expect(uids.one.before).not.toBe(uids.two.before)
+      expect(uids.one.before).toBe(uids.one.after)
+      expect(uids.two.before).toBe(uids.two.after)
+    })
+
     test('should teardown in-scope effects', async () => {
       let resolve: (val?: any) => void
       const ready = new Promise(r => {
@@ -266,7 +329,10 @@ describe('SFC <script setup> helpers', () => {
 
       const Comp = defineComponent({
         async setup() {
-          await withAsyncContext(Promise.resolve())
+          let __temp: any, __restore: any
+          ;[__temp, __restore] = withAsyncContext(() => Promise.resolve())
+          __temp = await __temp
+          __restore()
 
           c = computed(() => {})
           // register the lifecycle after an await statement
index 5bbcffd9575b08e0c8961ae7694a832db1f6f3fe..4dcadbabdde9de9f7eac0607c29b4aaa85ca8358 100644 (file)
@@ -1,9 +1,9 @@
 import { isPromise } from '../../shared/src'
 import {
   getCurrentInstance,
+  setCurrentInstance,
   SetupContext,
-  createSetupContext,
-  setCurrentInstance
+  createSetupContext
 } from './component'
 import { EmitFn, EmitsOptions } from './componentEmits'
 import {
@@ -230,25 +230,32 @@ export function mergeDefaults(
 }
 
 /**
- * Runtime helper for storing and resuming current instance context in
- * async setup().
+ * `<script setup>` helper for persisting the current instance context over
+ * async/await flows.
+ *
+ * `@vue/compiler-sfc` converts the following:
+ *
+ * ```ts
+ * const x = await foo()
+ * ```
+ *
+ * into:
+ *
+ * ```ts
+ * let __temp, __restore
+ * const x = (([__temp, __restore] = withAsyncContext(() => foo())),__temp=await __temp,__restore(),__temp)
+ * ```
+ * @internal
  */
-export function withAsyncContext<T>(awaitable: T | Promise<T>): Promise<T> {
+export function withAsyncContext(getAwaitable: () => any) {
   const ctx = getCurrentInstance()
-  setCurrentInstance(null) // unset after storing instance
-  if (__DEV__ && !ctx) {
-    warn(`withAsyncContext() called when there is no active context instance.`)
+  let awaitable = getAwaitable()
+  setCurrentInstance(null)
+  if (isPromise(awaitable)) {
+    awaitable = awaitable.catch(e => {
+      setCurrentInstance(ctx)
+      throw e
+    })
   }
-  return isPromise<T>(awaitable)
-    ? awaitable.then(
-        res => {
-          setCurrentInstance(ctx)
-          return res
-        },
-        err => {
-          setCurrentInstance(ctx)
-          throw err
-        }
-      )
-    : (awaitable as any)
+  return [awaitable, () => setCurrentInstance(ctx)]
 }