CIFAR-10 es un conjunto de datos de imágenes pequeñas (32x32 píxeles, 3 canales) organizadas en 10 clases: avión, automóvil, pájaro, gato, ciervo, perro, rana, caballo, barco y camión.
# Cargar el dataset CIFAR-10 desde Keras
from tensorflow.keras.datasets import cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Normalizar los valores de píxeles (0 a 255) a un rango entre 0 y 1
x_train = x_train / 255.0
x_test = x_test / 255.0
Se cargan los datos y se escalan para que la red neuronal los procese más eficientemente.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# Crear un modelo secuencial con varias capas convolucionales y densas
model = Sequential([
# Capa convolucional con 32 filtros, tamaño 3x3 y función de activación ReLU
Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
MaxPooling2D(2, 2), # Submuestreo para reducir la dimensión
# Segunda capa convolucional
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D(2, 2),
# Aplanar la salida para pasarla a capas densas
Flatten(),
Dense(64, activation='relu'), # Capa oculta
Dense(10, activation='softmax') # Capa de salida para clasificación multiclase
])
Modelo CNN típico para clasificación de imágenes, ideal para reconocer patrones espaciales.
# Compilar el modelo: función de pérdida y optimizador
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# Entrenar el modelo con 10 épocas y 20% de validación
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)
Se utiliza el optimizador Adam y una función de pérdida adecuada para clasificación multiclase con etiquetas enteras.
import matplotlib.pyplot as plt
# Graficar precisión durante el entrenamiento
plt.plot(history.history['accuracy'], label='Entrenamiento')
plt.plot(history.history['val_accuracy'], label='Validación')
plt.xlabel('Época')
plt.ylabel('Precisión')
plt.legend()
plt.title('Precisión del modelo')
plt.show()
Visualizamos cómo evoluciona la precisión en el entrenamiento y validación para detectar overfitting.
import numpy as np
import matplotlib.pyplot as plt
import random
# Obtener las predicciones del modelo para el set de prueba
predictions = model.predict(x_test)
predicted_labels = np.argmax(predictions, axis=1)
# Las clases del dataset CIFAR-10
cifar10_classes = [
'airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck'
]
Se obtiene la clase más probable para cada imagen del conjunto de prueba.
# Mostrar 5 imágenes aleatorias del set de prueba con sus predicciones
for i in range(5):
random_index = random.randint(0, len(x_test) - 1)
plt.figure() # Crea una nueva figura para cada imagen
plt.imshow(x_test[random_index])
plt.title(f"Etiqueta real: {cifar10_classes[y_test[random_index][0]]} - Predicción: {cifar10_classes[predicted_labels[random_index]]}")
plt.axis('off')
plt.show()
Visualizamos una predicción del modelo para comprobar su funcionamiento en la práctica.