Epsilon617 committed on
Commit 92cd759 · Parent(s): c2c7513

add genre prediction head
    	
Prediction_Head/MTGGenre_head.py    ADDED

@@ -0,0 +1,21 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class MLPProberBase(nn.Module):
+    def __init__(self, d=768, num_outputs=87):
+        super().__init__()
+        self.hidden_layer_sizes = [512, ] # eval(self.cfg.hidden_layer_sizes)
+        self.num_layers = len(self.hidden_layer_sizes)
+        for i, ld in enumerate(self.hidden_layer_sizes):
+            setattr(self, f"hidden_{i}", nn.Linear(d, ld))
+            d = ld
+        self.output = nn.Linear(d, num_outputs)
+
+    def forward(self, x):
+        for i in range(self.num_layers):
+            x = getattr(self, f"hidden_{i}")(x)
+            # x = self.dropout(x)
+            x = F.relu(x)
+        output = self.output(x)
+        return output
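For reference, a minimal sketch of how the new prober head is exercised on its own: it maps a 768-dimensional MERT embedding to 87 genre logits. The class and its defaults come from the file above; the random tensor is only a stand-in for a time-averaged MERT hidden state.

import torch

from Prediction_Head.MTGGenre_head import MLPProberBase

head = MLPProberBase(d=768, num_outputs=87)
head.eval()

# Stand-in for a time-averaged 768-dim MERT hidden state.
dummy_embedding = torch.randn(768)
with torch.no_grad():
    logits = head(dummy_embedding)  # one logit per MTG genre tag
print(logits.shape)  # torch.Size([87])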
    	
Prediction_Head/MTGGenre_id2class.json    ADDED

@@ -0,0 +1 @@
+{"0": "genre---rock", "1": "genre---pop", "2": "genre---classical", "3": "genre---popfolk", "4": "genre---disco", "5": "genre---funk", "6": "genre---rnb", "7": "genre---ambient", "8": "genre---chillout", "9": "genre---downtempo", "10": "genre---easylistening", "11": "genre---electronic", "12": "genre---lounge", "13": "genre---triphop", "14": "genre---breakbeat", "15": "genre---techno", "16": "genre---newage", "17": "genre---jazz", "18": "genre---metal", "19": "genre---industrial", "20": "genre---instrumentalrock", "21": "genre---minimal", "22": "genre---alternative", "23": "genre---experimental", "24": "genre---drumnbass", "25": "genre---soul", "26": "genre---fusion", "27": "genre---soundtrack", "28": "genre---electropop", "29": "genre---world", "30": "genre---ethno", "31": "genre---trance", "32": "genre---orchestral", "33": "genre---grunge", "34": "genre---chanson", "35": "genre---worldfusion", "36": "genre---hiphop", "37": "genre---groove", "38": "genre---instrumentalpop", "39": "genre---blues", "40": "genre---reggae", "41": "genre---dance", "42": "genre---club", "43": "genre---punkrock", "44": "genre---folk", "45": "genre---synthpop", "46": "genre---poprock", "47": "genre---choir", "48": "genre---symphonic", "49": "genre---indie", "50": "genre---progressive", "51": "genre---acidjazz", "52": "genre---contemporary", "53": "genre---newwave", "54": "genre---dub", "55": "genre---rocknroll", "56": "genre---hard", "57": "genre---hardrock", "58": "genre---house", "59": "genre---atmospheric", "60": "genre---psychedelic", "61": "genre---improvisation", "62": "genre---country", "63": "genre---electronica", "64": "genre---rap", "65": "genre---60s", "66": "genre---70s", "67": "genre---darkambient", "68": "genre---idm", "69": "genre---latin", "70": "genre---postrock", "71": "genre---bossanova", "72": "genre---singersongwriter", "73": "genre---darkwave", "74": "genre---swing", "75": "genre---medieval", "76": "genre---celtic", "77": "genre---eurodance", "78": "genre---classicrock", "79": "genre---dubstep", "80": "genre---bluesrock", "81": "genre---edm", "82": "genre---deephouse", "83": "genre---jazzfusion", "84": "genre---alternativerock", "85": "genre---80s", "86": "genre---90s"}
    	
Prediction_Head/__pycache__/MTGGenre_head.cpython-310.pyc    ADDED

Binary file (1.08 kB)
    	
Prediction_Head/best_MTGGenre.ckpt    ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83b7dcffde10a0dc7ba74341ea56dabec5c5de7cad6a0483708c80f1d893514a
+size 1759067
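What lands in the repository here is only a Git LFS pointer; the actual ~1.76 MB checkpoint is fetched with git lfs pull. A sketch of how the weights are consumed, mirroring the load_state_dict call this commit adds to app.py (map_location is an assumption for CPU-only environments):

import torch
from Prediction_Head.MTGGenre_head import MLPProberBase

# Assumes the LFS object has been pulled, so the .ckpt holds real weights.
MTGGenre_classifier = MLPProberBase()
state = torch.load('Prediction_Head/best_MTGGenre.ckpt', map_location='cpu')
MTGGenre_classifier.load_state_dict(state['state_dict'])
MTGGenre_classifier.eval()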
    	
__pycache__/app.cpython-310.pyc    CHANGED

Binary files a/__pycache__/app.cpython-310.pyc and b/__pycache__/app.cpython-310.pyc differ
    	
app.py    CHANGED

@@ -8,9 +8,12 @@ import torchaudio
 import torchaudio.transforms as T
 import logging
 
+import json
+
 import importlib 
 modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")
 
+from Prediction_Head.MTGGenre_head import MLPProberBase 
 # input cr: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py
 
 
@@ -34,7 +37,7 @@ live_inputs = [
 ]
 # outputs = [gr.components.Textbox()]
 # outputs = [gr.components.Textbox(), transcription_df]
-title = "
+title = "Predict the top 5 possible genres of Music"
 description = "An example of using MERT-95M-public to conduct music tagging."
 article = ""
 audio_examples = [
@@ -48,9 +51,17 @@ audio_examples = [
 model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
 processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")
 
+MERT_LAYER_IDX = 7
+MTGGenre_classifier = MLPProberBase()
+MTGGenre_classifier.load_state_dict(torch.load('Prediction_Head/best_MTGGenre.ckpt')['state_dict'])
+
+with open('Prediction_Head/MTGGenre_id2class.json', 'r') as f:
+   id2cls=json.load(f)
+
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 model.to(device)
+MTGGenre_classifier.to(device)
 
 def convert_audio(inputs, microphone):
     if (microphone is not None):
@@ -75,10 +86,17 @@ def convert_audio(inputs, microphone):
     # take a look at the output shape, there are 13 layers of representation
     # each layer performs differently in different downstream tasks, you should choose empirically
     all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
-    
+    print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
+
+    logits = MTGGenre_classifier(torch.mean(all_layer_hidden_states[MERT_LAYER_IDX], dim=0)) # [1, 87]
+    print(logits.shape)
+    sorted_idx = torch.argsort(logits, dim = -1, descending=True)
+    
+    output_texts = "\n".join([id2cls[str(idx.item())].replace('genre---', '') for idx in sorted_idx[:5]])
     # logger.warning(all_layer_hidden_states.shape)
 
-    return f"device {device}, sample reprensentation:  {str(all_layer_hidden_states[12, 0, :10])}"
+    # return f"device {device}, sample reprensentation:  {str(all_layer_hidden_states[12, 0, :10])}"
+    return f"device: {device}\n" + output_texts
 
 def live_convert_audio(microphone):
     if (microphone is not None):
@@ -103,10 +121,17 @@ def live_convert_audio(microphone):
     # take a look at the output shape, there are 13 layers of representation
     # each layer performs differently in different downstream tasks, you should choose empirically
     all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
-    
+    print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
+
+    logits = MTGGenre_classifier(torch.mean(all_layer_hidden_states[MERT_LAYER_IDX], dim=0)) # [1, 87]
+    print(logits.shape)
+    sorted_idx = torch.argsort(logits, dim = -1, descending=True)
+    
+    output_texts = "\n".join([id2cls[str(idx.item())].replace('genre---', '') for idx in sorted_idx[:5]])
     # logger.warning(all_layer_hidden_states.shape)
 
-    return f"device {device}, sample reprensentation:  {str(all_layer_hidden_states[12, 0, :10])}"
+    # return f"device {device}, sample reprensentation:  {str(all_layer_hidden_states[12, 0, :10])}"
+    return f"device: {device}\n" + output_texts
 
 
 audio_chunked = gr.Interface(
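Putting the pieces of this commit together outside of Gradio, the inference path added to both convert_audio and live_convert_audio looks roughly like the sketch below. The MERT loading and feature-extraction calls are assumptions based on the surrounding app.py (only fragments of which appear in this diff), and example.wav is a hypothetical input file; the prober head, checkpoint, JSON mapping, and layer index 7 come from the commit itself.

import json
import importlib

import torch
import torchaudio
from transformers import Wav2Vec2FeatureExtractor

from Prediction_Head.MTGGenre_head import MLPProberBase

modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")

# Load MERT and the new genre head (paths as in app.py).
model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")
MTGGenre_classifier = MLPProberBase()
MTGGenre_classifier.load_state_dict(
    torch.load('Prediction_Head/best_MTGGenre.ckpt')['state_dict'])
with open('Prediction_Head/MTGGenre_id2class.json') as f:
    id2cls = json.load(f)

# Hypothetical input; resample to the rate the feature extractor expects.
waveform, sr = torchaudio.load("example.wav")
waveform = torchaudio.functional.resample(waveform, sr, processor.sampling_rate)
inputs = processor(waveform.mean(dim=0).numpy(),
                   sampling_rate=processor.sampling_rate, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    # [13 layers, time steps, 768]; mean-pool layer 7 over time as in the commit.
    hidden = torch.stack(outputs.hidden_states).squeeze()
    logits = MTGGenre_classifier(torch.mean(hidden[7], dim=0))

top5 = torch.argsort(logits, descending=True)[:5]
print("\n".join(id2cls[str(i.item())].replace('genre---', '') for i in top5))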
