¿Qué es CIFAR-10?

CIFAR-10 es un conjunto de datos de imágenes pequeñas (32x32 píxeles, 3 canales) organizadas en 10 clases: avión, automóvil, pájaro, gato, ciervo, perro, rana, caballo, barco y camión.

Preparación de los datos


# Cargar el dataset CIFAR-10 desde Keras
from tensorflow.keras.datasets import cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Normalizar los valores de píxeles (0 a 255) a un rango entre 0 y 1
x_train = x_train / 255.0
x_test = x_test / 255.0

Se cargan los datos y se escalan para que la red neuronal los procese más eficientemente.

Arquitectura del modelo


from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

# Crear un modelo secuencial con varias capas convolucionales y densas
model = Sequential([
# Capa convolucional con 32 filtros, tamaño 3x3 y función de activación ReLU
Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
MaxPooling2D(2, 2),  # Submuestreo para reducir la dimensión

# Segunda capa convolucional
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D(2, 2),

# Aplanar la salida para pasarla a capas densas
Flatten(),
Dense(64, activation='relu'),     # Capa oculta
Dense(10, activation='softmax')   # Capa de salida para clasificación multiclase
])

Modelo CNN típico para clasificación de imágenes, ideal para reconocer patrones espaciales.

Entrenamiento del modelo


# Compilar el modelo: función de pérdida y optimizador
model.compile(optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy'])

# Entrenar el modelo con 10 épocas y 20% de validación
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)

Se utiliza el optimizador Adam y una función de pérdida adecuada para clasificación multiclase con etiquetas enteras.

Evaluación del modelo


import matplotlib.pyplot as plt

# Graficar precisión durante el entrenamiento
plt.plot(history.history['accuracy'], label='Entrenamiento')
plt.plot(history.history['val_accuracy'], label='Validación')
plt.xlabel('Época')
plt.ylabel('Precisión')
plt.legend()
plt.title('Precisión del modelo')
plt.show()

Visualizamos cómo evoluciona la precisión en el entrenamiento y validación para detectar overfitting.

Predicciones


import numpy as np
import matplotlib.pyplot as plt
import random

# Obtener las predicciones del modelo para el set de prueba
predictions = model.predict(x_test)
predicted_labels = np.argmax(predictions, axis=1)

# Las clases del dataset CIFAR-10
cifar10_classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]
                

Se obtiene la clase más probable para cada imagen del conjunto de prueba.

Ejemplo de predicción


# Mostrar 5 imágenes aleatorias del set de prueba con sus predicciones
for i in range(5):
    random_index = random.randint(0, len(x_test) - 1)
    plt.figure()  # Crea una nueva figura para cada imagen
    plt.imshow(x_test[random_index])
    plt.title(f"Etiqueta real: {cifar10_classes[y_test[random_index][0]]} - Predicción: {cifar10_classes[predicted_labels[random_index]]}")
    plt.axis('off')
    plt.show()

Visualizamos una predicción del modelo para comprobar su funcionamiento en la práctica.