File size: 5,704 Bytes
f01c9d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
/**
 * Per-tag online logistic-regression classifier.
 *
 * Every tag owns an independent weight vector and bias. Feedback samples
 * nudge those parameters with a single gradient step, and `predict` returns
 * the sigmoid of the resulting linear score. The model can be persisted to
 * `localStorage` under the key `clipTaggerModel`.
 */
class LocalClassifier {
  constructor() {
    this.weights = new Map(); // tag -> weight vector
    this.biases = new Map(); // tag -> bias
    this.learningRate = 0.01;
    this.featureDim = 512; // CLAP embedding dimension
    this.isInitialized = false;
  }

  /**
   * Set the feature dimension used for newly created weight vectors.
   *
   * @param {number} [featureDim=512] - Length of new weight vectors.
   */
  initialize(featureDim = 512) {
    this.featureDim = featureDim;
    this.isInitialized = true;
  }

  /**
   * Linear score `bias + sum(weights[i] * features[i])`.
   *
   * Only the overlapping prefix of the two vectors is used, so a feature
   * vector that is longer or shorter than the weight vector never reads an
   * undefined slot.
   *
   * @param {ArrayLike<number>} features - Input feature vector.
   * @param {number[]} weights - Weight vector of a tag.
   * @param {number} bias - Bias of the same tag.
   * @returns {number} The un-squashed score (logit).
   */
  computeLogit(features, weights, bias) {
    const dim = Math.min(features.length, weights.length);
    let logit = bias;
    for (let i = 0; i < dim; i++) {
      logit += weights[i] * features[i];
    }
    return logit;
  }

  /**
   * Apply one logistic-regression gradient step for `tag`.
   *
   * `'positive'` and `'custom'` feedback train towards 1, `'negative'`
   * towards 0; any other feedback value is ignored. A sample whose score is
   * not finite (NaN / Infinity features) is skipped entirely so that it
   * cannot overwrite the tag's parameters with NaN.
   *
   * @param {ArrayLike<number>} features - Feature vector of the sample.
   * @param {string} tag - Tag the feedback refers to.
   * @param {string} feedback - `'positive'`, `'negative'` or `'custom'`.
   */
  trainOnFeedback(features, tag, feedback) {
    if (!this.isInitialized) {
      // Keep whatever dimension is currently configured instead of
      // silently resetting it to the default.
      this.initialize(this.featureDim);
    }

    // Convert feedback to target value
    let target;
    switch (feedback) {
      case 'positive':
      case 'custom':
        target = 1.0;
        break;
      case 'negative':
        target = 0.0;
        break;
      default:
        return; // Skip unknown feedback
    }

    // Small random weights for a tag seen for the first time. They are only
    // registered once the sample has been validated below.
    const isNewTag = !this.weights.has(tag);
    const weights = isNewTag
      ? Array.from({ length: this.featureDim }, () => (Math.random() - 0.5) * 0.01)
      : this.weights.get(tag);
    const bias = isNewTag ? 0 : this.biases.get(tag);

    // Forward pass
    const logit = this.computeLogit(features, weights, bias);
    if (!Number.isFinite(logit)) {
      return; // Invalid sample: do not corrupt (or create) the tag's parameters
    }

    // Sigmoid activation
    const prediction = 1 / (1 + Math.exp(-logit));

    // Gradient of the log-loss with respect to the logit
    const error = prediction - target;

    // Update only the dimensions that took part in the forward pass
    const dim = Math.min(features.length, weights.length);
    for (let i = 0; i < dim; i++) {
      weights[i] -= this.learningRate * error * features[i];
    }

    this.weights.set(tag, weights);
    this.biases.set(tag, bias - this.learningRate * error);
  }

  /**
   * Predict the confidence for `tag` given a feature vector.
   *
   * @param {ArrayLike<number>} features - Input feature vector.
   * @param {string} tag - Tag to score.
   * @returns {number|null} Sigmoid confidence, or `null` when the tag has
   *   no weights yet.
   */
  predict(features, tag) {
    if (!this.weights.has(tag)) {
      return null; // No training data for this tag
    }

    const logit = this.computeLogit(
      features,
      this.weights.get(tag),
      this.biases.get(tag)
    );

    // Sigmoid activation
    return 1 / (1 + Math.exp(-logit));
  }

  /**
   * Score every candidate tag that has weights and rank the results.
   *
   * @param {ArrayLike<number>} features - Input feature vector.
   * @param {Iterable<string>} candidateTags - Tags to evaluate.
   * @returns {{tag: string, confidence: number}[]} Predictions sorted by
   *   descending confidence; tags without weights are omitted.
   */
  predictAll(features, candidateTags) {
    const predictions = [];

    for (const tag of candidateTags) {
      const confidence = this.predict(features, tag);
      if (confidence !== null) {
        predictions.push({ tag, confidence });
      }
    }

    return predictions.sort((a, b) => b.confidence - a.confidence);
  }

  /**
   * Train on a batch of stored feedback items.
   *
   * Items lacking `audioFeatures` or `correctedTags` (and empty entries)
   * are skipped.
   *
   * @param {Iterable<{audioFeatures: object, correctedTags: Iterable<{tag: string, feedback: string}>}>} feedbackData
   */
  retrainOnBatch(feedbackData) {
    for (const item of feedbackData) {
      if (!item || !item.audioFeatures || !item.correctedTags) {
        continue;
      }

      // Create simple features from audio metadata
      const features = this.extractSimpleFeatures(item.audioFeatures);

      // Train on corrected tags
      for (const tagData of item.correctedTags) {
        if (tagData) {
          this.trainOnFeedback(features, tagData.tag, tagData.feedback);
        }
      }
    }
  }

  /**
   * Build a deterministic pseudo feature vector from audio metadata.
   *
   * In a real implementation, this would use actual CLAP embeddings.
   * Slots 0..2 carry duration in minutes, sample rate relative to 48 kHz
   * and the channel count; a missing or non-numeric property contributes 0
   * instead of NaN. Remaining slots are filled with small pseudo-random
   * values seeded by a hash of the serialized metadata.
   *
   * @param {object|null|undefined} audioFeatures - Metadata with optional
   *   `duration`, `sampleRate` and `numberOfChannels` properties.
   * @returns {number[]} Vector of exactly `featureDim` finite numbers.
   */
  extractSimpleFeatures(audioFeatures) {
    const features = new Array(this.featureDim).fill(0);

    if (!audioFeatures) {
      return features;
    }

    // Use basic audio properties to create pseudo-features
    const leading = [
      audioFeatures.duration / 60, // Duration in minutes
      audioFeatures.sampleRate / 48000, // Normalized sample rate
      Number(audioFeatures.numberOfChannels), // Number of channels
    ];

    // Never write past featureDim, and never store NaN / Infinity.
    const leadingCount = Math.min(leading.length, this.featureDim);
    for (let i = 0; i < leadingCount; i++) {
      features[i] = Number.isFinite(leading[i]) ? leading[i] : 0;
    }

    // Fill remaining with small random values based on hash of properties
    const seed = this.simpleHash(JSON.stringify(audioFeatures));
    for (let i = leading.length; i < this.featureDim; i++) {
      features[i] = this.seededRandom(seed + i) * 0.1;
    }

    return features;
  }

  /**
   * Simple 31-multiplier string hash used to seed `seededRandom`.
   *
   * @param {string} str - Text to hash.
   * @returns {number} Non-negative integer hash.
   */
  simpleHash(str) {
    let hash = 0;
    for (let i = 0; i < str.length; i++) {
      // (hash << 5) - hash === hash * 31; `| 0` wraps to a 32-bit integer
      hash = (((hash << 5) - hash) + str.charCodeAt(i)) | 0;
    }
    return Math.abs(hash);
  }

  /**
   * Deterministic pseudo-random number derived from `seed`.
   *
   * @param {number} seed - Any number.
   * @returns {number} Fractional part of `sin(seed) * 10000`, in [0, 1).
   */
  seededRandom(seed) {
    const x = Math.sin(seed) * 10000;
    return x - Math.floor(x);
  }

  /**
   * Serialize the model to `localStorage` under `clipTaggerModel`.
   *
   * Errors raised by `localStorage.setItem` are not caught here and
   * propagate to the caller.
   */
  saveModel() {
    const modelData = {
      weights: Object.fromEntries(this.weights),
      biases: Object.fromEntries(this.biases),
      featureDim: this.featureDim,
      learningRate: this.learningRate
    };

    localStorage.setItem('clipTaggerModel', JSON.stringify(modelData));
  }

  /**
   * Restore the model from `localStorage`.
   *
   * The stored payload is validated completely before any instance state
   * is replaced, so a malformed payload leaves the current model untouched.
   *
   * @returns {boolean} `true` when a model was loaded, `false` when nothing
   *   is stored or the stored data could not be read/parsed.
   */
  loadModel() {
    try {
      const saved = localStorage.getItem('clipTaggerModel');
      if (!saved) {
        return false;
      }

      const modelData = JSON.parse(saved);
      const isRecord = (value) => typeof value === 'object' && value !== null;
      if (!isRecord(modelData) || !isRecord(modelData.weights) || !isRecord(modelData.biases)) {
        throw new TypeError('Stored model is malformed');
      }

      const weights = new Map();
      const biases = new Map();
      for (const [tag, vector] of Object.entries(modelData.weights)) {
        if (!Array.isArray(vector)) {
          throw new TypeError(`Stored weights for tag "${tag}" are malformed`);
        }
        const bias = modelData.biases[tag];
        weights.set(tag, vector);
        // A tag without a usable bias falls back to 0 rather than NaN.
        biases.set(tag, Number.isFinite(bias) ? bias : 0);
      }

      // Commit only after everything parsed and validated.
      this.weights = weights;
      this.biases = biases;
      this.featureDim = modelData.featureDim || 512;
      this.learningRate = modelData.learningRate || 0.01;
      this.isInitialized = true;
      return true;
    } catch (error) {
      console.error('Error loading model:', error);
      return false;
    }
  }

  /**
   * Summarize the current model.
   *
   * @returns {{trainedTags: number, featureDim: number, learningRate: number, tags: string[]}}
   */
  getModelStats() {
    return {
      trainedTags: this.weights.size,
      featureDim: this.featureDim,
      learningRate: this.learningRate,
      tags: [...this.weights.keys()]
    };
  }

  /**
   * Drop all in-memory parameters and remove the persisted model.
   */
  clearModel() {
    this.weights.clear();
    this.biases.clear();
    localStorage.removeItem('clipTaggerModel');
  }
}

export default LocalClassifier;