]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
ALSA: timer: Don't take register_mutex with copy_from/to_user()
authorTakashi Iwai <tiwai@suse.de>
Fri, 21 Mar 2025 17:26:52 +0000 (18:26 +0100)
committerTakashi Iwai <tiwai@suse.de>
Fri, 21 Mar 2025 17:28:28 +0000 (18:28 +0100)
The infamous mmap_lock taken in copy_from/to_user() can be often
problematic when it's called inside another mutex, as they might lead
to deadlocks.

In the case of ALSA timer code, the bad pattern is with
guard(mutex)(&register_mutex) that covers copy_from/to_user() -- which
was mistakenly introduced at converting to guard(), and it had been
carefully worked around in the past.

This patch fixes those pieces simply by moving copy_from/to_user() out
of the register mutex lock again.

Fixes: 3923de04c817 ("ALSA: pcm: oss: Use guard() for setup")
Reported-by: syzbot+2b96f44164236dda0f3b@syzkaller.appspotmail.com
Closes: https://lore.kernel.org/67dd86c8.050a0220.25ae54.0059.GAE@google.com
Link: https://patch.msgid.link/20250321172653.14310-1-tiwai@suse.de
Signed-off-by: Takashi Iwai <tiwai@suse.de>
sound/core/timer.c

index fbada79380f9eaeaa2bb71d84acb4a90a374c5f7..d774b9b71ce2382400442d488d35edd93e6e9501 100644 (file)
@@ -1515,91 +1515,97 @@ static void snd_timer_user_copy_id(struct snd_timer_id *id, struct snd_timer *ti
        id->subdevice = timer->tmr_subdevice;
 }
 
-static int snd_timer_user_next_device(struct snd_timer_id __user *_tid)
+static void get_next_device(struct snd_timer_id *id)
 {
-       struct snd_timer_id id;
        struct snd_timer *timer;
        struct list_head *p;
 
-       if (copy_from_user(&id, _tid, sizeof(id)))
-               return -EFAULT;
-       guard(mutex)(&register_mutex);
-       if (id.dev_class < 0) {         /* first item */
+       if (id->dev_class < 0) {                /* first item */
                if (list_empty(&snd_timer_list))
-                       snd_timer_user_zero_id(&id);
+                       snd_timer_user_zero_id(id);
                else {
                        timer = list_entry(snd_timer_list.next,
                                           struct snd_timer, device_list);
-                       snd_timer_user_copy_id(&id, timer);
+                       snd_timer_user_copy_id(id, timer);
                }
        } else {
-               switch (id.dev_class) {
+               switch (id->dev_class) {
                case SNDRV_TIMER_CLASS_GLOBAL:
-                       id.device = id.device < 0 ? 0 : id.device + 1;
+                       id->device = id->device < 0 ? 0 : id->device + 1;
                        list_for_each(p, &snd_timer_list) {
                                timer = list_entry(p, struct snd_timer, device_list);
                                if (timer->tmr_class > SNDRV_TIMER_CLASS_GLOBAL) {
-                                       snd_timer_user_copy_id(&id, timer);
+                                       snd_timer_user_copy_id(id, timer);
                                        break;
                                }
-                               if (timer->tmr_device >= id.device) {
-                                       snd_timer_user_copy_id(&id, timer);
+                               if (timer->tmr_device >= id->device) {
+                                       snd_timer_user_copy_id(id, timer);
                                        break;
                                }
                        }
                        if (p == &snd_timer_list)
-                               snd_timer_user_zero_id(&id);
+                               snd_timer_user_zero_id(id);
                        break;
                case SNDRV_TIMER_CLASS_CARD:
                case SNDRV_TIMER_CLASS_PCM:
-                       if (id.card < 0) {
-                               id.card = 0;
+                       if (id->card < 0) {
+                               id->card = 0;
                        } else {
-                               if (id.device < 0) {
-                                       id.device = 0;
+                               if (id->device < 0) {
+                                       id->device = 0;
                                } else {
-                                       if (id.subdevice < 0)
-                                               id.subdevice = 0;
-                                       else if (id.subdevice < INT_MAX)
-                                               id.subdevice++;
+                                       if (id->subdevice < 0)
+                                               id->subdevice = 0;
+                                       else if (id->subdevice < INT_MAX)
+                                               id->subdevice++;
                                }
                        }
                        list_for_each(p, &snd_timer_list) {
                                timer = list_entry(p, struct snd_timer, device_list);
-                               if (timer->tmr_class > id.dev_class) {
-                                       snd_timer_user_copy_id(&id, timer);
+                               if (timer->tmr_class > id->dev_class) {
+                                       snd_timer_user_copy_id(id, timer);
                                        break;
                                }
-                               if (timer->tmr_class < id.dev_class)
+                               if (timer->tmr_class < id->dev_class)
                                        continue;
-                               if (timer->card->number > id.card) {
-                                       snd_timer_user_copy_id(&id, timer);
+                               if (timer->card->number > id->card) {
+                                       snd_timer_user_copy_id(id, timer);
                                        break;
                                }
-                               if (timer->card->number < id.card)
+                               if (timer->card->number < id->card)
                                        continue;
-                               if (timer->tmr_device > id.device) {
-                                       snd_timer_user_copy_id(&id, timer);
+                               if (timer->tmr_device > id->device) {
+                                       snd_timer_user_copy_id(id, timer);
                                        break;
                                }
-                               if (timer->tmr_device < id.device)
+                               if (timer->tmr_device < id->device)
                                        continue;
-                               if (timer->tmr_subdevice > id.subdevice) {
-                                       snd_timer_user_copy_id(&id, timer);
+                               if (timer->tmr_subdevice > id->subdevice) {
+                                       snd_timer_user_copy_id(id, timer);
                                        break;
                                }
-                               if (timer->tmr_subdevice < id.subdevice)
+                               if (timer->tmr_subdevice < id->subdevice)
                                        continue;
-                               snd_timer_user_copy_id(&id, timer);
+                               snd_timer_user_copy_id(id, timer);
                                break;
                        }
                        if (p == &snd_timer_list)
-                               snd_timer_user_zero_id(&id);
+                               snd_timer_user_zero_id(id);
                        break;
                default:
-                       snd_timer_user_zero_id(&id);
+                       snd_timer_user_zero_id(id);
                }
        }
+}
+
+static int snd_timer_user_next_device(struct snd_timer_id __user *_tid)
+{
+       struct snd_timer_id id;
+
+       if (copy_from_user(&id, _tid, sizeof(id)))
+               return -EFAULT;
+       scoped_guard(mutex, &register_mutex)
+               get_next_device(&id);
        if (copy_to_user(_tid, &id, sizeof(*_tid)))
                return -EFAULT;
        return 0;
@@ -1620,23 +1626,24 @@ static int snd_timer_user_ginfo(struct file *file,
        tid = ginfo->tid;
        memset(ginfo, 0, sizeof(*ginfo));
        ginfo->tid = tid;
-       guard(mutex)(&register_mutex);
-       t = snd_timer_find(&tid);
-       if (!t)
-               return -ENODEV;
-       ginfo->card = t->card ? t->card->number : -1;
-       if (t->hw.flags & SNDRV_TIMER_HW_SLAVE)
-               ginfo->flags |= SNDRV_TIMER_FLG_SLAVE;
-       strscpy(ginfo->id, t->id, sizeof(ginfo->id));
-       strscpy(ginfo->name, t->name, sizeof(ginfo->name));
-       scoped_guard(spinlock_irq, &t->lock)
-               ginfo->resolution = snd_timer_hw_resolution(t);
-       if (t->hw.resolution_min > 0) {
-               ginfo->resolution_min = t->hw.resolution_min;
-               ginfo->resolution_max = t->hw.resolution_max;
-       }
-       list_for_each(p, &t->open_list_head) {
-               ginfo->clients++;
+       scoped_guard(mutex, &register_mutex) {
+               t = snd_timer_find(&tid);
+               if (!t)
+                       return -ENODEV;
+               ginfo->card = t->card ? t->card->number : -1;
+               if (t->hw.flags & SNDRV_TIMER_HW_SLAVE)
+                       ginfo->flags |= SNDRV_TIMER_FLG_SLAVE;
+               strscpy(ginfo->id, t->id, sizeof(ginfo->id));
+               strscpy(ginfo->name, t->name, sizeof(ginfo->name));
+               scoped_guard(spinlock_irq, &t->lock)
+                       ginfo->resolution = snd_timer_hw_resolution(t);
+               if (t->hw.resolution_min > 0) {
+                       ginfo->resolution_min = t->hw.resolution_min;
+                       ginfo->resolution_max = t->hw.resolution_max;
+               }
+               list_for_each(p, &t->open_list_head) {
+                       ginfo->clients++;
+               }
        }
        if (copy_to_user(_ginfo, ginfo, sizeof(*ginfo)))
                return -EFAULT;
@@ -1674,31 +1681,31 @@ static int snd_timer_user_gstatus(struct file *file,
        struct snd_timer_gstatus gstatus;
        struct snd_timer_id tid;
        struct snd_timer *t;
-       int err = 0;
 
        if (copy_from_user(&gstatus, _gstatus, sizeof(gstatus)))
                return -EFAULT;
        tid = gstatus.tid;
        memset(&gstatus, 0, sizeof(gstatus));
        gstatus.tid = tid;
-       guard(mutex)(&register_mutex);
-       t = snd_timer_find(&tid);
-       if (t != NULL) {
-               guard(spinlock_irq)(&t->lock);
-               gstatus.resolution = snd_timer_hw_resolution(t);
-               if (t->hw.precise_resolution) {
-                       t->hw.precise_resolution(t, &gstatus.resolution_num,
-                                                &gstatus.resolution_den);
+       scoped_guard(mutex, &register_mutex) {
+               t = snd_timer_find(&tid);
+               if (t != NULL) {
+                       guard(spinlock_irq)(&t->lock);
+                       gstatus.resolution = snd_timer_hw_resolution(t);
+                       if (t->hw.precise_resolution) {
+                               t->hw.precise_resolution(t, &gstatus.resolution_num,
+                                                        &gstatus.resolution_den);
+                       } else {
+                               gstatus.resolution_num = gstatus.resolution;
+                               gstatus.resolution_den = 1000000000uL;
+                       }
                } else {
-                       gstatus.resolution_num = gstatus.resolution;
-                       gstatus.resolution_den = 1000000000uL;
+                       return -ENODEV;
                }
-       } else {
-               err = -ENODEV;
        }
-       if (err >= 0 && copy_to_user(_gstatus, &gstatus, sizeof(gstatus)))
-               err = -EFAULT;
-       return err;
+       if (copy_to_user(_gstatus, &gstatus, sizeof(gstatus)))
+               return -EFAULT;
+       return 0;
 }
 
 static int snd_timer_user_tselect(struct file *file,