diff --git a/libraries/exoplayer_dash/src/main/java/androidx/media3/exoplayer/dash/DashMediaPeriod.java b/libraries/exoplayer_dash/src/main/java/androidx/media3/exoplayer/dash/DashMediaPeriod.java index 5ad8c87734a..0657539e95f 100644 --- a/libraries/exoplayer_dash/src/main/java/androidx/media3/exoplayer/dash/DashMediaPeriod.java +++ b/libraries/exoplayer_dash/src/main/java/androidx/media3/exoplayer/dash/DashMediaPeriod.java @@ -748,7 +748,7 @@ private static int identifyEmbeddedTracks( primaryGroupClosedCaptionTrackFormats[i] = getClosedCaptionTrackFormats(adaptationSets, groupedAdaptationSetIndices[i]); if (primaryGroupClosedCaptionTrackFormats[i].length != 0) { - numEmbeddedTrackGroups++; + numEmbeddedTrackGroups += primaryGroupClosedCaptionTrackFormats[i].length; } } return numEmbeddedTrackGroups; @@ -790,7 +790,7 @@ private static int buildPrimaryAndEmbeddedTrackGroupInfos( int eventMessageTrackGroupIndex = primaryGroupHasEventMessageTrackFlags[i] ? trackGroupCount++ : C.INDEX_UNSET; int closedCaptionTrackGroupIndex = - primaryGroupClosedCaptionTrackFormats[i].length != 0 ? trackGroupCount++ : C.INDEX_UNSET; + primaryGroupClosedCaptionTrackFormats[i].length != 0 ? trackGroupCount : C.INDEX_UNSET; maybeUpdateFormatsForParsedText(chunkSourceFactory, formats); trackGroups[primaryTrackGroupIndex] = new TrackGroup(trackGroupId, formats); @@ -800,7 +800,8 @@ private static int buildPrimaryAndEmbeddedTrackGroupInfos( adaptationSetIndices, primaryTrackGroupIndex, eventMessageTrackGroupIndex, - closedCaptionTrackGroupIndex); + closedCaptionTrackGroupIndex, + primaryGroupClosedCaptionTrackFormats[i].length); if (eventMessageTrackGroupIndex != C.INDEX_UNSET) { String eventMessageTrackGroupId = trackGroupId + ":emsg"; Format format = @@ -814,23 +815,34 @@ private static int buildPrimaryAndEmbeddedTrackGroupInfos( TrackGroupInfo.embeddedEmsgTrack(adaptationSetIndices, primaryTrackGroupIndex); } if (closedCaptionTrackGroupIndex != C.INDEX_UNSET) { - String closedCaptionTrackGroupId = trackGroupId + ":cc"; - trackGroupInfos[closedCaptionTrackGroupIndex] = - TrackGroupInfo.embeddedClosedCaptionTrack( - adaptationSetIndices, - primaryTrackGroupIndex, - ImmutableList.copyOf(primaryGroupClosedCaptionTrackFormats[i])); - maybeUpdateFormatsForParsedText( - chunkSourceFactory, primaryGroupClosedCaptionTrackFormats[i]); - for (int j = 0; j < primaryGroupClosedCaptionTrackFormats[i].length; j++) { - primaryGroupClosedCaptionTrackFormats[i][j] = - primaryGroupClosedCaptionTrackFormats[i][j] + Format[] closedCaptionFormats = primaryGroupClosedCaptionTrackFormats[i]; + maybeUpdateFormatsForParsedText(chunkSourceFactory, closedCaptionFormats); + + for (int currentCaptionIndex = 0; currentCaptionIndex < closedCaptionFormats.length; + currentCaptionIndex++) { + // primaryTrackGroupId should be updated first to end up in both + // trackGroupInfos and trackGroups. + primaryGroupClosedCaptionTrackFormats[i][currentCaptionIndex] = + closedCaptionFormats[currentCaptionIndex] .buildUpon() .setPrimaryTrackGroupId(trackGroupId) .build(); + + trackGroupInfos[closedCaptionTrackGroupIndex] = + TrackGroupInfo.embeddedClosedCaptionTrack( + adaptationSetIndices, + primaryTrackGroupIndex, + primaryGroupClosedCaptionTrackFormats[i][currentCaptionIndex]); + + String closedCaptionTrackGroupId = trackGroupId + ":cc:" + currentCaptionIndex; + trackGroups[closedCaptionTrackGroupIndex] = + new TrackGroup(closedCaptionTrackGroupId, + primaryGroupClosedCaptionTrackFormats[i][currentCaptionIndex]); + + closedCaptionTrackGroupIndex++; } - trackGroups[closedCaptionTrackGroupIndex] = - new TrackGroup(closedCaptionTrackGroupId, primaryGroupClosedCaptionTrackFormats[i]); + + trackGroupCount += closedCaptionFormats.length; } } return trackGroupCount; @@ -865,11 +877,18 @@ private ChunkSampleStream buildSampleStream( trackGroups.get(trackGroupInfo.embeddedEventMessageTrackGroupIndex); embeddedTrackCount++; } - ImmutableList embeddedClosedCaptionOriginalFormats = - trackGroupInfo.embeddedClosedCaptionTrackGroupIndex != C.INDEX_UNSET - ? trackGroupInfos[trackGroupInfo.embeddedClosedCaptionTrackGroupIndex] - .embeddedClosedCaptionTrackOriginalFormats - : ImmutableList.of(); + List embeddedClosedCaptionOriginalFormats = new ArrayList<>(); + if (trackGroupInfo.embeddedClosedCaptionTrackGroupStartIndex != C.INDEX_UNSET) { + for (int i = 0; i < trackGroupInfo.embeddedClosedCaptionTrackGroupLength; i++) { + Format closedCaptionsFormat = + trackGroupInfos[trackGroupInfo.embeddedClosedCaptionTrackGroupStartIndex + i] + .embeddedClosedCaptionTrackOriginalFormat; + + if (closedCaptionsFormat != null) { + embeddedClosedCaptionOriginalFormats.add(closedCaptionsFormat); + } + } + } embeddedTrackCount += embeddedClosedCaptionOriginalFormats.size(); Format[] embeddedTrackFormats = new Format[embeddedTrackCount]; @@ -1119,26 +1138,30 @@ private static final class TrackGroupInfo { public final int eventStreamGroupIndex; public final int primaryTrackGroupIndex; public final int embeddedEventMessageTrackGroupIndex; - public final int embeddedClosedCaptionTrackGroupIndex; + public final int embeddedClosedCaptionTrackGroupStartIndex; + public final int embeddedClosedCaptionTrackGroupLength; - /** Only non-empty for track groups representing embedded caption tracks. */ - public final ImmutableList embeddedClosedCaptionTrackOriginalFormats; + /** Only non-null for track groups representing embedded caption tracks. */ + @Nullable + public final Format embeddedClosedCaptionTrackOriginalFormat; public static TrackGroupInfo primaryTrack( int trackType, int[] adaptationSetIndices, int primaryTrackGroupIndex, int embeddedEventMessageTrackGroupIndex, - int embeddedClosedCaptionTrackGroupIndex) { + int embeddedClosedCaptionTrackGroupStartIndex, + int embeddedClosedCaptionTrackGroupLength) { return new TrackGroupInfo( trackType, CATEGORY_PRIMARY, adaptationSetIndices, primaryTrackGroupIndex, embeddedEventMessageTrackGroupIndex, - embeddedClosedCaptionTrackGroupIndex, + embeddedClosedCaptionTrackGroupStartIndex, + embeddedClosedCaptionTrackGroupLength, /* eventStreamGroupIndex= */ -1, - /* embeddedClosedCaptionTrackOriginalFormats= */ ImmutableList.of()); + /* embeddedClosedCaptionTrackOriginalFormat= */ null); } public static TrackGroupInfo embeddedEmsgTrack( @@ -1150,14 +1173,15 @@ public static TrackGroupInfo embeddedEmsgTrack( primaryTrackGroupIndex, C.INDEX_UNSET, C.INDEX_UNSET, + C.LENGTH_UNSET, /* eventStreamGroupIndex= */ -1, - /* embeddedClosedCaptionTrackOriginalFormats= */ ImmutableList.of()); + /* embeddedClosedCaptionTrackOriginalFormat= */ null); } public static TrackGroupInfo embeddedClosedCaptionTrack( int[] adaptationSetIndices, int primaryTrackGroupIndex, - ImmutableList originalFormats) { + Format originalFormat) { return new TrackGroupInfo( C.TRACK_TYPE_TEXT, CATEGORY_EMBEDDED, @@ -1165,8 +1189,9 @@ public static TrackGroupInfo embeddedClosedCaptionTrack( primaryTrackGroupIndex, C.INDEX_UNSET, C.INDEX_UNSET, + C.LENGTH_UNSET, /* eventStreamGroupIndex= */ -1, - originalFormats); + originalFormat); } public static TrackGroupInfo mpdEventTrack(int eventStreamIndex) { @@ -1177,8 +1202,9 @@ public static TrackGroupInfo mpdEventTrack(int eventStreamIndex) { /* primaryTrackGroupIndex= */ -1, C.INDEX_UNSET, C.INDEX_UNSET, + C.LENGTH_UNSET, eventStreamIndex, - /* embeddedClosedCaptionTrackOriginalFormats= */ ImmutableList.of()); + /* embeddedClosedCaptionTrackOriginalFormat= */ null); } private TrackGroupInfo( @@ -1187,17 +1213,19 @@ private TrackGroupInfo( int[] adaptationSetIndices, int primaryTrackGroupIndex, int embeddedEventMessageTrackGroupIndex, - int embeddedClosedCaptionTrackGroupIndex, + int embeddedClosedCaptionTrackGroupStartIndex, + int embeddedClosedCaptionTrackGroupLength, int eventStreamGroupIndex, - ImmutableList embeddedClosedCaptionTrackOriginalFormats) { + @Nullable Format embeddedClosedCaptionTrackOriginalFormat) { this.trackType = trackType; this.adaptationSetIndices = adaptationSetIndices; this.trackGroupCategory = trackGroupCategory; this.primaryTrackGroupIndex = primaryTrackGroupIndex; this.embeddedEventMessageTrackGroupIndex = embeddedEventMessageTrackGroupIndex; - this.embeddedClosedCaptionTrackGroupIndex = embeddedClosedCaptionTrackGroupIndex; + this.embeddedClosedCaptionTrackGroupStartIndex = embeddedClosedCaptionTrackGroupStartIndex; + this.embeddedClosedCaptionTrackGroupLength = embeddedClosedCaptionTrackGroupLength; this.eventStreamGroupIndex = eventStreamGroupIndex; - this.embeddedClosedCaptionTrackOriginalFormats = embeddedClosedCaptionTrackOriginalFormats; + this.embeddedClosedCaptionTrackOriginalFormat = embeddedClosedCaptionTrackOriginalFormat; } } } diff --git a/libraries/exoplayer_dash/src/test/java/androidx/media3/exoplayer/dash/DashMediaPeriodTest.java b/libraries/exoplayer_dash/src/test/java/androidx/media3/exoplayer/dash/DashMediaPeriodTest.java index c4a4bb06ce9..2bb3fc8c630 100644 --- a/libraries/exoplayer_dash/src/test/java/androidx/media3/exoplayer/dash/DashMediaPeriodTest.java +++ b/libraries/exoplayer_dash/src/test/java/androidx/media3/exoplayer/dash/DashMediaPeriodTest.java @@ -15,8 +15,13 @@ */ package androidx.media3.exoplayer.dash; +import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import android.net.Uri; @@ -34,7 +39,10 @@ import androidx.media3.exoplayer.source.CompositeSequenceableLoaderFactory; import androidx.media3.exoplayer.source.MediaSource.MediaPeriodId; import androidx.media3.exoplayer.source.MediaSourceEventListener; +import androidx.media3.exoplayer.source.SampleStream; import androidx.media3.exoplayer.source.TrackGroupArray; +import androidx.media3.exoplayer.trackselection.ExoTrackSelection; +import androidx.media3.exoplayer.trackselection.FixedTrackSelection; import androidx.media3.exoplayer.upstream.Allocator; import androidx.media3.exoplayer.upstream.LoadErrorHandlingPolicy; import androidx.media3.exoplayer.upstream.LoaderErrorThrower; @@ -47,6 +55,7 @@ import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; /** Unit tests for {@link DashMediaPeriod}. */ @RunWith(AndroidJUnit4.class) @@ -198,8 +207,8 @@ public void cea608AccessibilityDescriptor_createsCea608TrackGroup() throws IOExc DashMediaPeriod dashMediaPeriod = createDashMediaPeriod(manifest, 0); List adaptationSets = manifest.getPeriod(0).adaptationSets; - // We expect two adaptation sets. The first containing the video representations, and the second - // containing the embedded CEA-608 tracks. + // We expect three track groups. The first containing the video representations, + // and the other two containing the embedded CEA-608 tracks. Format.Builder cea608FormatBuilder = new Format.Builder().setSampleMimeType(MimeTypes.APPLICATION_CEA608); TrackGroupArray expectedTrackGroups = @@ -209,13 +218,15 @@ public void cea608AccessibilityDescriptor_createsCea608TrackGroup() throws IOExc adaptationSets.get(0).representations.get(0).format, adaptationSets.get(0).representations.get(1).format), new TrackGroup( - /* id= */ "123:cc", + /* id= */ "123:cc:0", cea608FormatBuilder .setId("123:cea608:1") .setLanguage("eng") .setAccessibilityChannel(1) .setPrimaryTrackGroupId("123") - .build(), + .build()), + new TrackGroup( + /* id= */ "123:cc:1", cea608FormatBuilder .setId("123:cea608:3") .setLanguage("deu") @@ -232,8 +243,8 @@ public void cea708AccessibilityDescriptor_createsCea708TrackGroup() throws IOExc DashMediaPeriod dashMediaPeriod = createDashMediaPeriod(manifest, 0); List adaptationSets = manifest.getPeriod(0).adaptationSets; - // We expect two adaptation sets. The first containing the video representations, and the second - // containing the embedded CEA-708 tracks. + // We expect three track groups. The first containing the video representations, + // and the other two containing the embedded CEA-708 tracks. Format.Builder cea608FormatBuilder = new Format.Builder().setSampleMimeType(MimeTypes.APPLICATION_CEA708); TrackGroupArray expectedTrackGroups = @@ -243,13 +254,15 @@ public void cea708AccessibilityDescriptor_createsCea708TrackGroup() throws IOExc adaptationSets.get(0).representations.get(0).format, adaptationSets.get(0).representations.get(1).format), new TrackGroup( - /* id= */ "123:cc", + /* id= */ "123:cc:0", cea608FormatBuilder .setId("123:cea708:1") .setLanguage("eng") .setAccessibilityChannel(1) .setPrimaryTrackGroupId("123") - .build(), + .build()), + new TrackGroup( + /* id= */ "123:cc:1", cea608FormatBuilder .setId("123:cea708:2") .setLanguage("deu") @@ -285,9 +298,50 @@ public void inbandEventStream_createsEmsgTrackGroups() throws IOException { MediaPeriodAsserts.assertTrackGroups(dashMediaPeriod, expectedTrackGroups); } + @Test + public void buildSampleStream_enclosesAllClosedCaptions() throws IOException { + DashManifest manifest = parseManifest("media/mpd/sample_mpd_cea_608_accessibility"); + DashChunkSource.Factory factory = mock(DashChunkSource.Factory.class); + DashMediaPeriod dashMediaPeriod = createDashMediaPeriod(manifest, factory, 0); + + ExoTrackSelection primaryTrackSelection = + new FixedTrackSelection(dashMediaPeriod.getTrackGroups().get(0), /* track= */ 0); + + List closedCaptionsFormats = List.of( + dashMediaPeriod.getTrackGroups().get(1).getFormat(0), + dashMediaPeriod.getTrackGroups().get(2).getFormat(0)); + + SampleStream[] streams = new SampleStream[4]; + dashMediaPeriod.selectTracks( + new ExoTrackSelection[]{ primaryTrackSelection }, + new boolean[4], + streams, + new boolean[4], + /* positionUs= */ 0L + ); + + @SuppressWarnings("unchecked") + ArgumentCaptor> closedCaptionFormatsCaptor = + ArgumentCaptor.forClass(List.class); + + verify(factory) + .createDashChunkSource(any(), any(), any(), anyInt(), any(), any(), anyInt(), anyLong(), + anyBoolean(), closedCaptionFormatsCaptor.capture(), any(), any(), any(), any()); + + Format[] actualFormats = closedCaptionFormatsCaptor.getValue().toArray(new Format[0]); + assertThat(actualFormats) + .asList() + .containsExactlyElementsIn(closedCaptionsFormats) + .inOrder(); + } + private static DashMediaPeriod createDashMediaPeriod(DashManifest manifest, int periodIndex) { + return createDashMediaPeriod(manifest, mock(DashChunkSource.Factory.class), periodIndex); + } + + private static DashMediaPeriod createDashMediaPeriod(DashManifest manifest, + DashChunkSource.Factory chunkSourceFactory, int periodIndex) { MediaPeriodId mediaPeriodId = new MediaPeriodId(/* periodUid= */ new Object()); - DashChunkSource.Factory chunkSourceFactory = mock(DashChunkSource.Factory.class); when(chunkSourceFactory.getOutputTextFormat(any())) .then(invocation -> invocation.getArguments()[0]); return new DashMediaPeriod(